diff --git a/src/Simplex/FileTransfer/Agent.hs b/src/Simplex/FileTransfer/Agent.hs index c3723e7a9..ff271240b 100644 --- a/src/Simplex/FileTransfer/Agent.hs +++ b/src/Simplex/FileTransfer/Agent.hs @@ -223,6 +223,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do agentXFTPDownloadChunk c userId digest replica chunkSpec liftIO $ waitUntilForeground c (entityId, complete, progress) <- withStore c $ \db -> runExceptT $ do + liftIO $ lockRcvFileForUpdate db rcvFileId liftIO $ updateRcvFileChunkReceived db (rcvChunkReplicaId replica) rcvChunkId relChunkPath RcvFile {size = FileSize currentSize, chunks, redirect} <- ExceptT $ getRcvFile db rcvFileId let rcvd = receivedSize chunks @@ -413,6 +414,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting (digest, chunkSpecsDigests) <- encryptFileForUpload sndFile fsEncPath withStore c $ \db -> do + lockSndFileForUpdate db sndFileId updateSndFileEncrypted db sndFileId digest chunkSpecsDigests getSndFile db sndFileId else pure sndFile @@ -530,6 +532,7 @@ runXFTPSndWorker c srv Worker {doWork} = do agentXFTPUploadChunk c userId chunkDigest replica' chunkSpec' liftIO $ waitUntilForeground c sf@SndFile {sndFileEntityId, prefixPath, chunks} <- withStore c $ \db -> do + lockSndFileForUpdate db sndFileId updateSndChunkReplicaStatus db sndChunkReplicaId SFRSUploaded getSndFile db sndFileId let uploaded = uploadedSize chunks diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 3130e0227..01d95596b 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -1129,7 +1129,8 @@ startJoinInvitation c userId connId sq_ enableNtfs cReqUri pqSup = let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport} case sq_ of Just sq@SndQueue {e2ePubKey = Just _k} -> do - e2eSndParams <- withStore c $ \db -> + e2eSndParams <- withStore c $ \db -> do + lockConnForUpdate db connId getSndRatchet db connId v >>= \case Right r -> pure $ Right $ snd r Left e -> do @@ -1143,6 +1144,7 @@ startJoinInvitation c userId connId sq_ enableNtfs cReqUri pqSup = sndKey_ = snd <$> invLink_ (q, _) <- lift $ newSndQueue userId "" qInfo sndKey_ withStore c $ \db -> runExceptT $ do + liftIO $ lockConnForUpdate db connId e2eSndParams <- createRatchet_ db g maxSupported pqSupport e2eRcvParams sq' <- maybe (ExceptT $ updateNewConnSnd db connId q) pure sq_ pure (cData, sq', e2eSndParams, lnkId_) @@ -1221,7 +1223,8 @@ joinConnSrv c nm userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup su AgentConfig {smpClientVRange = vr, smpAgentVRange, e2eEncryptVRange = e2eVR} <- asks config let qUri = SMPQueueUri vr $ (rcvSMPQueueAddress rq) {queueMode = Just QMMessaging} crData = ConnReqUriData SSSimplex smpAgentVRange [qUri] Nothing - e2eRcvParams <- withStore' c $ \db -> + e2eRcvParams <- withStore' c $ \db -> do + lockConnForUpdate db connId getRatchetX3dhKeys db connId >>= \case Right keys -> pure $ CR.mkRcvE2ERatchetParams (maxVersion e2eVR) keys Left e -> do @@ -1937,7 +1940,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} sq@SndQueue {userId, connId, server, withRetryLock2 ri' qLock $ \riState loop -> do liftIO $ waitWhileSuspended c liftIO $ waitForUserNetwork c - resp <- tryError $ case msgType of + resp <- tryAllErrors $ case msgType of AM_CONN_INFO -> sendConfirmation c NRMBackground sq msgBody AM_CONN_INFO_REPLY -> sendConfirmation c NRMBackground sq msgBody _ -> case pendingMsgPrepData_ of @@ -2077,10 +2080,12 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} sq@SndQueue {userId, connId, server, notifyDelMsgs :: InternalId -> AgentErrorType -> UTCTime -> AM () notifyDelMsgs msgId err expireTs = do notifyDel msgId $ MERR (unId msgId) err - msgIds_ <- withStore' c $ \db -> getExpiredSndMessages db connId sq expireTs + msgIds_ <- withStore' c $ \db -> do + msgIds_ <- getExpiredSndMessages db connId sq expireTs + forM_ msgIds_ $ \msgId' -> deleteSndMsgDelivery db connId sq msgId' False `catchAll_` pure () + pure msgIds_ forM_ (L.nonEmpty msgIds_) $ \msgIds -> do notify $ MERRS (L.map unId msgIds) err - withStore' c $ \db -> forM_ msgIds $ \msgId' -> deleteSndMsgDelivery db connId sq msgId' False `catchAll_` pure () atomically $ incSMPServerStat' c userId server sentExpiredErrs (length msgIds_ + 1) delMsg :: InternalId -> AM () delMsg = delMsgKeep False @@ -3005,7 +3010,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId throwE e agentClientMsg :: TVar ChaChaDRG -> ByteString -> AM (Maybe (InternalId, MsgMeta, AMessage, CR.RatchetX448)) agentClientMsg g encryptedMsgHash = withStore c $ \db -> runExceptT $ do - rc <- ExceptT $ getRatchet db connId -- ratchet state pre-decryption - required for processing EREADY + liftIO $ lockConnForUpdate db connId + rc <- ExceptT $ getRatchetForUpdate db connId -- ratchet state pre-decryption - required for processing EREADY (agentMsgBody, pqEncryption) <- agentRatchetDecrypt' g db connId rc encAgentMessage liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do @@ -3240,6 +3246,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId Just sqs' -> do (sq_@SndQueue {sndPrivateKey}, dhPublicKey) <- lift $ newSndQueue userId connId qInfo Nothing sq2 <- withStore c $ \db -> do + lockConnForUpdate db connId liftIO $ mapM_ (deleteConnSndQueue db connId) delSqs addConnSndQueue db connId (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId} logServer "<--" c srv rId $ "MSG :" <> logSecret' srvMsgId <> " " <> logSecret (senderId queueAddress) @@ -3544,7 +3551,7 @@ agentRatchetEncrypt db cData msg getPaddedLen pqEnc_ currentE2EVersion = do agentRatchetEncryptHeader :: DB.Connection -> ConnData -> (VersionSMPA -> PQSupport -> Int) -> Maybe PQEncryption -> CR.VersionE2E -> ExceptT StoreError IO (CR.MsgEncryptKeyX448, Int, PQEncryption) agentRatchetEncryptHeader db ConnData {connId, connAgentVersion = v, pqSupport} getPaddedLen pqEnc_ currentE2EVersion = do - rc <- ExceptT $ getRatchet db connId + rc <- ExceptT $ getRatchetForUpdate db connId let paddedLen = getPaddedLen v pqSupport (mek, rc') <- withExceptT (SEAgentError . cryptoError) $ CR.rcEncryptHeader rc pqEnc_ currentE2EVersion liftIO $ updateRatchet db connId rc' CR.SMDNoChange @@ -3553,7 +3560,7 @@ agentRatchetEncryptHeader db ConnData {connId, connAgentVersion = v, pqSupport} -- encoded EncAgentMessage -> encoded AgentMessage agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption) agentRatchetDecrypt g db connId encAgentMsg = do - rc <- ExceptT $ getRatchet db connId + rc <- ExceptT $ getRatchetForUpdate db connId agentRatchetDecrypt' g db connId rc encAgentMsg agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index f90c4db4f..ebc5410d1 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -2114,16 +2114,17 @@ withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError ( withWork c doWork = withWork_ c doWork . withStore' c {-# INLINE withWork #-} +-- setting doWork flag to "no work" before getWork rather than after prevents race condition when flag is set to "has work" by another thread after getWork call. withWork_ :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' (Maybe a)) -> (a -> ExceptT e m ()) -> ExceptT e m () withWork_ c doWork getWork action = - getWork >>= \case - Right (Just r) -> action r - Right Nothing -> noWork - -- worker is stopped here (noWork) because the next iteration is likely to produce the same result + noWork >> getWork >>= \case + Right (Just r) -> hasWork >> action r + Right Nothing -> pure () Left e - | isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e - | otherwise -> notifyErr INTERNAL e + | isWorkItemError e -> notifyErr (CRITICAL False) e -- worker remains stopped here because the next iteration is likely to produce the same result + | otherwise -> hasWork >> notifyErr INTERNAL e where + hasWork = atomically $ hasWorkToDo' doWork noWork = liftIO $ noWorkToDo doWork notifyErr err e = do logError $ "withWork_ error: " <> tshow e @@ -2131,22 +2132,24 @@ withWork_ c doWork getWork action = withWorkItems :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' [Either e' a]) -> (NonEmpty a -> ExceptT e m ()) -> ExceptT e m () withWorkItems c doWork getWork action = do - getWork >>= \case - Right [] -> noWork + noWork >> getWork >>= \case + Right [] -> pure () Right rs -> do let (errs, items) = partitionEithers rs case L.nonEmpty items of - Just items' -> action items' + Just items' -> hasWork >> action items' Nothing -> do - let criticalErr = find isWorkItemError errs - forM_ criticalErr $ \err -> do - notifyErr (CRITICAL False) err - when (all isWorkItemError errs) noWork + case find isWorkItemError errs of + Nothing -> hasWork + Just err -> do + notifyErr (CRITICAL False) err + unless (all isWorkItemError errs) hasWork forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map (\e -> ("", INTERNAL $ show e)) Left e - | isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e - | otherwise -> notifyErr INTERNAL e + | isWorkItemError e -> notifyErr (CRITICAL False) e + | otherwise -> hasWork >> notifyErr INTERNAL e where + hasWork = atomically $ hasWorkToDo' doWork noWork = liftIO $ noWorkToDo doWork notifyErr err e = do logError $ "withWorkItems error: " <> tshow e diff --git a/src/Simplex/Messaging/Agent/Store/AgentStore.hs b/src/Simplex/Messaging/Agent/Store/AgentStore.hs index a4c95af87..0469f09dd 100644 --- a/src/Simplex/Messaging/Agent/Store/AgentStore.hs +++ b/src/Simplex/Messaging/Agent/Store/AgentStore.hs @@ -52,6 +52,7 @@ module Simplex.Messaging.Agent.Store.AgentStore getConnSubs, getDeletedConns, getConnsData, + lockConnForUpdate, setConnDeleted, setConnUserId, setConnAgentVersion, @@ -140,6 +141,7 @@ module Simplex.Messaging.Agent.Store.AgentStore createRatchet, deleteRatchet, getRatchet, + getRatchetForUpdate, getSkippedMsgKeys, updateRatchet, -- Async commands @@ -187,6 +189,7 @@ module Simplex.Messaging.Agent.Store.AgentStore -- Rcv files createRcvFile, createRcvFileRedirect, + lockRcvFileForUpdate, getRcvFile, getRcvFileByEntityId, getRcvFileRedirects, @@ -207,6 +210,7 @@ module Simplex.Messaging.Agent.Store.AgentStore getRcvFilesExpired, -- Snd files createSndFile, + lockSndFileForUpdate, getSndFile, getSndFileByEntityId, getNextSndFileToPrepare, @@ -405,7 +409,7 @@ createNewConn db gVar cData cMode = do -- TODO [certs rcv] store clientServiceId from NewRcvQueue updateNewConnRcv :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue) updateNewConnRcv db connId rq subMode = - getConn db connId $>>= \case + getConnForUpdate db connId $>>= \case (SomeConn _ NewConnection {}) -> updateConn (SomeConn _ RcvConnection {}) -> updateConn -- to allow retries (SomeConn c _) -> pure . Left . SEBadConnType "updateNewConnRcv" $ connType c @@ -415,7 +419,7 @@ updateNewConnRcv db connId rq subMode = updateNewConnSnd :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) updateNewConnSnd db connId sq = - getConn db connId $>>= \case + getConnForUpdate db connId $>>= \case (SomeConn _ NewConnection {}) -> updateConn (SomeConn c _) -> pure . Left . SEBadConnType "updateNewConnSnd" $ connType c where @@ -449,7 +453,11 @@ checkConfirmedSndQueueExists_ db SndQueue {server, sndId} = maybeFirstRow' False fromOnlyBI $ DB.query db - "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1" + ( "SELECT 1 FROM snd_queues WHERE host = ? AND port = ? AND snd_id = ? AND status != ? LIMIT 1" +#if defined(dpPostgres) + <> " FOR UPDATE" +#endif + ) (host server, port server, sndId, New) getRcvConn :: DB.Connection -> SMPServer -> SMP.RecipientId -> IO (Either StoreError (RcvQueue, SomeConn)) @@ -488,14 +496,14 @@ deleteConn db waitDeliveryTimeout_ connId = case waitDeliveryTimeout_ of upgradeRcvConnToDuplex :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) upgradeRcvConnToDuplex db connId sq = - getConn db connId $>>= \case + getConnForUpdate db connId $>>= \case (SomeConn _ RcvConnection {}) -> Right <$> addConnSndQueue_ db connId sq (SomeConn c _) -> pure . Left . SEBadConnType "upgradeRcvConnToDuplex" $ connType c -- TODO [certs rcv] store clientServiceId from NewRcvQueue upgradeSndConnToDuplex :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue) upgradeSndConnToDuplex db connId rq subMode = - getConn db connId >>= \case + getConnForUpdate db connId >>= \case Right (SomeConn _ SndConnection {}) -> Right <$> addConnRcvQueue_ db connId rq subMode Right (SomeConn c _) -> pure . Left . SEBadConnType "upgradeSndConnToDuplex" $ connType c _ -> pure $ Left SEConnNotFound @@ -503,7 +511,7 @@ upgradeSndConnToDuplex db connId rq subMode = -- TODO [certs rcv] store clientServiceId from NewRcvQueue addConnRcvQueue :: DB.Connection -> ConnId -> NewRcvQueue -> SubscriptionMode -> IO (Either StoreError RcvQueue) addConnRcvQueue db connId rq subMode = - getConn db connId >>= \case + getConnForUpdate db connId >>= \case Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnRcvQueue_ db connId rq subMode Right (SomeConn c _) -> pure . Left . SEBadConnType "addConnRcvQueue" $ connType c _ -> pure $ Left SEConnNotFound @@ -515,7 +523,7 @@ addConnRcvQueue_ db connId rq@RcvQueue {server} subMode = do addConnSndQueue :: DB.Connection -> ConnId -> NewSndQueue -> IO (Either StoreError SndQueue) addConnSndQueue db connId sq = - getConn db connId >>= \case + getConnForUpdate db connId >>= \case Right (SomeConn _ DuplexConnection {}) -> Right <$> addConnSndQueue_ db connId sq Right (SomeConn c _) -> pure . Left . SEBadConnType "addConnSndQueue" $ connType c _ -> pure $ Left SEConnNotFound @@ -1048,7 +1056,14 @@ setMsgUserAck :: DB.Connection -> ConnId -> InternalId -> IO (Either StoreError setMsgUserAck db connId agentMsgId = runExceptT $ do (dbRcvId, srvMsgId) <- ExceptT . firstRow id (SEMsgNotFound "setMsgUserAck") $ - DB.query db "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" (connId, agentMsgId) + DB.query + db + ( "SELECT rcv_queue_id, broker_id FROM rcv_messages WHERE conn_id = ? AND internal_id = ?" +#if defined(dbPostgres) + <> " FOR UPDATE" +#endif + ) + (connId, agentMsgId) rq <- ExceptT $ getRcvQueueById db connId dbRcvId liftIO $ DB.execute db "UPDATE rcv_messages SET user_ack = ? WHERE conn_id = ? AND internal_id = ?" (BI True, connId, agentMsgId) pure (rq, srvMsgId) @@ -1120,6 +1135,9 @@ deleteMsgContent db connId msgId = do deleteDeliveredSndMsg :: DB.Connection -> ConnId -> InternalId -> IO () deleteDeliveredSndMsg db connId msgId = do +#if defined(dbPostgres) + _ :: [Only Int] <- DB.query db "SELECT 1 FROM messages WHERE conn_id = ? AND internal_id = ? FOR UPDATE" (connId, msgId) +#endif cnt <- countPendingSndDeliveries_ db connId msgId when (cnt == 0) $ deleteMsg db connId msgId @@ -1138,11 +1156,15 @@ deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId keepForReceipt = do maybeFirstRow id $ DB.query db - [sql| - SELECT rcpt_status, snd_message_body_id FROM snd_messages - WHERE NOT EXISTS (SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ? AND failed = 0) - AND conn_id = ? AND internal_id = ? - |] + ( [sql| + SELECT rcpt_status, snd_message_body_id FROM snd_messages + WHERE NOT EXISTS (SELECT 1 FROM snd_message_deliveries WHERE conn_id = ? AND internal_id = ? AND failed = 0) + AND conn_id = ? AND internal_id = ? + |] +#if defined(dbPostgres) + <> " FOR UPDATE" +#endif + ) (connId, msgId, connId, msgId) deleteMsgAndBody :: (Maybe MsgReceiptStatus, Maybe Int64) -> IO () deleteMsgAndBody (rcptStatus_, sndMsgBodyId_) = do @@ -1151,9 +1173,11 @@ deleteSndMsgDelivery db connId SndQueue {dbQueueId} msgId keepForReceipt = do Just MROk -> deleteMsg _ -> if keepForReceipt then deleteMsgContent else deleteMsg del db connId msgId - forM_ sndMsgBodyId_ $ \bodyId -> - -- Delete message body if it is not used by any snd message. - -- The current snd message is already deleted by deleteMsg or cleared by deleteMsgContent. + forM_ sndMsgBodyId_ $ \bodyId -> do +#if defined(dbPostgres) + -- lock for concurrent deletion of different records in snd_messages pointing to the same record in snd_message_bodies + _ :: [Only Int] <- DB.query db "SELECT 1 FROM snd_message_bodies WHERE snd_message_body_id = ? FOR UPDATE" (Only bodyId) +#endif DB.execute db [sql| @@ -1260,9 +1284,25 @@ deleteRatchet :: DB.Connection -> ConnId -> IO () deleteRatchet db connId = DB.execute db "DELETE FROM ratchets WHERE conn_id = ?" (Only connId) +getRatchetForUpdate :: DB.Connection -> ConnId -> IO (Either StoreError RatchetX448) +getRatchetForUpdate = +#if defined(dbPostgres) + getRatchet_ (ratchetQuery <> " FOR UPDATE") +#else + getRatchet_ ratchetQuery +#endif +{-# INLINE getRatchetForUpdate #-} + getRatchet :: DB.Connection -> ConnId -> IO (Either StoreError RatchetX448) -getRatchet db connId = - firstRow' ratchet SERatchetNotFound $ DB.query db "SELECT ratchet_state FROM ratchets WHERE conn_id = ?" (Only connId) +getRatchet = getRatchet_ ratchetQuery +{-# INLINE getRatchet #-} + +ratchetQuery :: Query +ratchetQuery = "SELECT ratchet_state FROM ratchets WHERE conn_id = ?" + +getRatchet_ :: Query -> DB.Connection -> ConnId -> IO (Either StoreError RatchetX448) +getRatchet_ q db connId = + firstRow' ratchet SERatchetNotFound $ DB.query db q (Only connId) where ratchet = maybe (Left SERatchetNotFound) Right . fromOnly @@ -1963,13 +2003,15 @@ instance (ToField a, ToField b, ToField c, ToField d, ToField e, ToField f, -- | Creates a new server, if it doesn't exist, and returns the passed key hash if it is different from stored. createServer_ :: DB.Connection -> SMPServer -> IO (Maybe C.KeyHash) -createServer_ db newSrv@ProtocolServer {host, port, keyHash} = - getServerKeyHash_ db newSrv >>= \case - Right keyHash_ -> pure keyHash_ - Left _ -> insertNewServer_ $> Nothing +createServer_ db newSrv@ProtocolServer {host, port, keyHash} = do + r <- insertNewServer_ + if null r + then getServerKeyHash_ db newSrv >>= either E.throwIO pure + else pure Nothing where + insertNewServer_ :: IO [Only Int] insertNewServer_ = - DB.execute db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?)" (host, port, keyHash) + DB.query db "INSERT INTO servers (host, port, key_hash) VALUES (?,?,?) ON CONFLICT (host, port) DO NOTHING RETURNING 1" (host, port, keyHash) -- | Returns the passed server key hash if it is different from the stored one, or the error if the server does not exist. getServerKeyHash_ :: DB.Connection -> SMPServer -> IO (Either StoreError (Maybe C.KeyHash)) @@ -2166,23 +2208,27 @@ getConnIds :: DB.Connection -> IO [ConnId] getConnIds db = map fromOnly <$> DB.query_ db "SELECT conn_id FROM connections WHERE deleted = 0" getConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) -getConn = getAnyConn False +getConn = getAnyConn False False {-# INLINE getConn #-} +getConnForUpdate :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) +getConnForUpdate = getAnyConn False True +{-# INLINE getConnForUpdate #-} + getDeletedConn :: DB.Connection -> ConnId -> IO (Either StoreError SomeConn) -getDeletedConn = getAnyConn True +getDeletedConn = getAnyConn True False {-# INLINE getDeletedConn #-} -getAnyConn :: Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn) +getAnyConn :: Bool -> Bool -> DB.Connection -> ConnId -> IO (Either StoreError SomeConn) getAnyConn = getAnyConn_ getRcvQueuesByConnId_ getSndQueuesByConnId_ {-# INLINE getAnyConn #-} getAnyConn_ :: (DB.Connection -> ConnId -> IO (Maybe (NonEmpty rq))) -> (DB.Connection -> ConnId -> IO (Maybe (NonEmpty sq))) -> - (Bool -> DB.Connection -> ConnId -> IO (Either StoreError (SomeConn' rq sq))) -getAnyConn_ getRQs getSQs deleted' db connId = - getConnData deleted' db connId >>= \case + (Bool -> Bool -> DB.Connection -> ConnId -> IO (Either StoreError (SomeConn' rq sq))) +getAnyConn_ getRQs getSQs deleted' forUpdate db connId = + getConnData deleted' forUpdate db connId >>= \case Just (cData, cMode) -> do rQ <- getRQs db connId sQ <- getSQs db connId @@ -2281,28 +2327,39 @@ getAnyConns_ :: (DB.Connection -> ConnId -> IO (Maybe (NonEmpty rq))) -> (DB.Connection -> ConnId -> IO (Maybe (NonEmpty sq))) -> (Bool -> DB.Connection -> [ConnId] -> IO [Either StoreError (SomeConn' rq sq)]) -getAnyConns_ getRQs getSQs deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn_ getRQs getSQs deleted' db +getAnyConns_ getRQs getSQs deleted' db connIds = forM connIds $ E.handle handleDBError . getAnyConn_ getRQs getSQs deleted' False db getConnsData :: DB.Connection -> [ConnId] -> IO [Either StoreError (Maybe (ConnData, ConnectionMode))] -getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False db +getConnsData db connIds = forM connIds $ E.handle handleDBError . fmap Right . getConnData False False db handleDBError :: E.SomeException -> IO (Either StoreError a) handleDBError = pure . Left . SEInternal . bshow #endif -getConnData :: Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) -getConnData deleted' db connId' = +getConnData :: Bool -> Bool -> DB.Connection -> ConnId -> IO (Maybe (ConnData, ConnectionMode)) +getConnData deleted' forUpdate db connId' = maybeFirstRow rowToConnData $ DB.query db - [sql| - SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, - last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support - FROM connections - WHERE conn_id = ? AND deleted = ? - |] + ( [sql| + SELECT user_id, conn_id, conn_mode, smp_agent_version, enable_ntfs, + last_external_snd_msg_id, deleted, ratchet_sync_state, pq_support + FROM connections + WHERE conn_id = ? AND deleted = ? + |] +#if defined(dbPostgres) + <> (if forUpdate then " FOR UPDATE" else "") +#endif + ) (connId', BI deleted') +lockConnForUpdate :: DB.Connection -> ConnId -> IO () +lockConnForUpdate db connId = do +#if defined(dbPostgres) + _ :: [Only Int] <- DB.query db "SELECT 1 FROM connections WHERE conn_id = ? FOR UPDATE" (Only connId) +#endif + pure () + rowToConnData :: (UserId, ConnId, ConnectionMode, VersionSMPA, Maybe BoolInt, PrevExternalSndId, BoolInt, RatchetSyncState, PQSupport) -> (ConnData, ConnectionMode) rowToConnData (userId, connId, cMode, connAgentVersion, enableNtfs_, lastExternalSndId, BI deleted, ratchetSyncState, pqSupport) = (ConnData {userId, connId, connAgentVersion, enableNtfs = maybe True unBI enableNtfs_, lastExternalSndId, deleted, ratchetSyncState, pqSupport}, cMode) @@ -2347,7 +2404,11 @@ checkRatchetKeyHashExists db connId hash = maybeFirstRow' False fromOnlyBI $ DB.query db - "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" + ( "SELECT 1 FROM processed_ratchet_key_hashes WHERE conn_id = ? AND hash = ? LIMIT 1" +#if defined(dbPostgres) + <> " FOR UPDATE" +#endif + ) (connId, Binary hash) deleteRatchetKeyHashesExpired :: DB.Connection -> NominalDiffTime -> IO () @@ -2471,11 +2532,15 @@ retrieveLastIdsAndHashRcv_ dbConn connId = do [(lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash)] <- DB.query dbConn - [sql| - SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash - FROM connections - WHERE conn_id = ? - |] + ( [sql| + SELECT last_internal_msg_id, last_internal_rcv_msg_id, last_external_snd_msg_id, last_rcv_msg_hash + FROM connections + WHERE conn_id = ? + |] +#if defined(dbPostgres) + <> " FOR UPDATE" +#endif + ) (Only connId) return (lastInternalId, lastInternalRcvId, lastExternalSndId, lastRcvHash) @@ -2542,11 +2607,15 @@ retrieveLastIdsAndHashSnd_ dbConn connId = do firstRow id SEConnNotFound $ DB.query dbConn - [sql| - SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash - FROM connections - WHERE conn_id = ? - |] + ( [sql| + SELECT last_internal_msg_id, last_internal_snd_msg_id, last_snd_msg_hash + FROM connections + WHERE conn_id = ? + |] +#if defined(dbPostgres) + <> " FOR UPDATE" +#endif + ) (Only connId) updateLastIdsSnd_ :: DB.Connection -> ConnId -> InternalId -> InternalSndId -> IO () @@ -2636,19 +2705,19 @@ ntfSubAndSMPAction (NSANtf action) = (Just action, Nothing) ntfSubAndSMPAction (NSASMP action) = (Nothing, Just action) createXFTPServer_ :: DB.Connection -> XFTPServer -> IO Int64 -createXFTPServer_ db newSrv@ProtocolServer {host, port, keyHash} = - getXFTPServerId_ db newSrv >>= \case - Right srvId -> pure srvId - Left _ -> insertNewServer_ - where - insertNewServer_ = do - DB.execute db "INSERT INTO xftp_servers (xftp_host, xftp_port, xftp_key_hash) VALUES (?,?,?)" (host, port, keyHash) - insertedRowId db - -getXFTPServerId_ :: DB.Connection -> XFTPServer -> IO (Either StoreError Int64) -getXFTPServerId_ db ProtocolServer {host, port, keyHash} = do - firstRow fromOnly SEXFTPServerNotFound $ - DB.query db "SELECT xftp_server_id FROM xftp_servers WHERE xftp_host = ? AND xftp_port = ? AND xftp_key_hash = ?" (host, port, keyHash) +createXFTPServer_ db ProtocolServer {host, port, keyHash} = do + Only serverId : _ <- + DB.query + db + [sql| + INSERT INTO xftp_servers (xftp_host, xftp_port, xftp_key_hash) + VALUES (?, ?, ?) + ON CONFLICT (xftp_host, xftp_port, xftp_key_hash) + DO UPDATE SET xftp_host = EXCLUDED.xftp_host + RETURNING xftp_server_id + |] + (host, port, keyHash) + pure serverId createRcvFile :: DB.Connection -> TVar ChaChaDRG -> UserId -> FileDescription 'FRecipient -> FilePath -> FilePath -> CryptoFile -> Bool -> IO (Either StoreError RcvFileId) createRcvFile db gVar userId fd@FileDescription {chunks} prefixPath tmpPath file approvedRelays = runExceptT $ do @@ -2728,6 +2797,13 @@ getRcvFileRedirects db rcvFileId = do redirects <- fromOnly <$$> DB.query db "SELECT rcv_file_id FROM rcv_files WHERE redirect_id = ?" (Only rcvFileId) fmap catMaybes . forM redirects $ getRcvFile db >=> either (const $ pure Nothing) (pure . Just) +lockRcvFileForUpdate :: DB.Connection -> DBRcvFileId -> IO () +lockRcvFileForUpdate db rcvFileId = do +#if defined(dbPostgres) + _ :: [Only Int] <- DB.query db "SELECT 1 FROM rcv_files WHERE rcv_file_id = ? FOR UPDATE" (Only rcvFileId) +#endif + pure () + getRcvFile :: DB.Connection -> DBRcvFileId -> IO (Either StoreError RcvFile) getRcvFile db rcvFileId = runExceptT $ do f@RcvFile {rcvFileEntityId, userId, tmpPath} <- ExceptT getFile @@ -2739,11 +2815,15 @@ getRcvFile db rcvFileId = runExceptT $ do firstRow toFile SEFileNotFound $ DB.query db - [sql| - SELECT rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, deleted, redirect_id, redirect_entity_id, redirect_size, redirect_digest - FROM rcv_files - WHERE rcv_file_id = ? - |] + ( [sql| + SELECT rcv_file_entity_id, user_id, size, digest, key, nonce, chunk_size, prefix_path, tmp_path, save_path, save_file_key, save_file_nonce, status, deleted, redirect_id, redirect_entity_id, redirect_size, redirect_digest + FROM rcv_files + WHERE rcv_file_id = ? + |] +#if defined(dbPostgres) + <> " FOR UPDATE" +#endif + ) (Only rcvFileId) where toFile :: (RcvFileId, UserId, FileSize Int64, FileDigest, C.SbKey, C.CbNonce, FileSize Word32, FilePath, Maybe FilePath) :. (FilePath, Maybe C.SbKey, Maybe C.CbNonce, RcvFileStatus, BoolInt, Maybe DBRcvFileId, Maybe RcvFileId, Maybe (FileSize Int64), Maybe FileDigest) -> RcvFile @@ -3004,6 +3084,13 @@ getSndFileIdByEntityId_ db sndFileEntityId = firstRow fromOnly SEFileNotFound $ DB.query db "SELECT snd_file_id FROM snd_files WHERE snd_file_entity_id = ?" (Only (Binary sndFileEntityId)) +lockSndFileForUpdate :: DB.Connection -> DBSndFileId -> IO () +lockSndFileForUpdate db sndFileId = do +#if defined(dbPostgres) + _ :: [Only Int] <- DB.query db "SELECT 1 FROM snd_files WHERE snd_file_id = ? FOR UPDATE" (Only sndFileId) +#endif + pure () + getSndFile :: DB.Connection -> DBSndFileId -> IO (Either StoreError SndFile) getSndFile db sndFileId = runExceptT $ do f@SndFile {sndFileEntityId, userId, numRecipients, prefixPath} <- ExceptT getFile @@ -3015,11 +3102,15 @@ getSndFile db sndFileId = runExceptT $ do firstRow toFile SEFileNotFound $ DB.query db - [sql| - SELECT snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, digest, prefix_path, key, nonce, status, deleted, redirect_size, redirect_digest - FROM snd_files - WHERE snd_file_id = ? - |] + ( [sql| + SELECT snd_file_entity_id, user_id, path, src_file_key, src_file_nonce, num_recipients, digest, prefix_path, key, nonce, status, deleted, redirect_size, redirect_digest + FROM snd_files + WHERE snd_file_id = ? + |] +#if defined(dbPostgres) + <> " FOR UPDATE" +#endif + ) (Only sndFileId) where toFile :: (SndFileId, UserId, FilePath, Maybe C.SbKey, Maybe C.CbNonce, Int, Maybe FileDigest, Maybe FilePath, C.SbKey, C.CbNonce) :. (SndFileStatus, BoolInt, Maybe (FileSize Int64), Maybe FileDigest) -> SndFile diff --git a/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs b/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs index 7bd1103e8..deec20e2f 100644 --- a/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs +++ b/src/Simplex/Messaging/Agent/Store/Postgres/DB.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} module Simplex.Messaging.Agent.Store.Postgres.DB @@ -12,29 +13,32 @@ module Simplex.Messaging.Agent.Store.Postgres.DB execute, execute_, executeMany, - PSQL.query, - PSQL.query_, + query, + query_, blobFieldDecoder, fromTextField_, ) where +import qualified Control.Exception as E import Control.Monad (void) +import qualified Data.ByteString as B import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) import Data.Text (Text) import Data.Text.Encoding (decodeUtf8) import Data.Typeable (Typeable) import Data.Word (Word16, Word32) -import Database.PostgreSQL.Simple (ResultError (..)) +import Database.PostgreSQL.Simple (Connection, ResultError (..), SqlError (..), FromRow, ToRow) import qualified Database.PostgreSQL.Simple as PSQL import Database.PostgreSQL.Simple.FromField (Field (..), FieldParser, FromField (..), returnError) import Database.PostgreSQL.Simple.ToField (ToField (..)) import Database.PostgreSQL.Simple.TypeInfo.Static (textOid, varcharOid) +import Database.PostgreSQL.Simple.Types (Query (..)) newtype BoolInt = BI {unBI :: Bool} -type SQLError = PSQL.SqlError +type SQLError = SqlError instance FromField BoolInt where fromField field dat = BI . (/= (0 :: Int)) <$> fromField field dat @@ -44,18 +48,30 @@ instance ToField BoolInt where toField (BI b) = toField ((if b then 1 else 0) :: Int) {-# INLINE toField #-} -execute :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> q -> IO () -execute db q qs = void $ PSQL.execute db q qs +execute :: ToRow q => PSQL.Connection -> Query -> q -> IO () +execute db q qs = void $ PSQL.execute db q qs `E.catch` addSql q {-# INLINE execute #-} -execute_ :: PSQL.Connection -> PSQL.Query -> IO () -execute_ db q = void $ PSQL.execute_ db q +execute_ :: PSQL.Connection -> Query -> IO () +execute_ db q = void $ PSQL.execute_ db q `E.catch` addSql q {-# INLINE execute_ #-} -executeMany :: PSQL.ToRow q => PSQL.Connection -> PSQL.Query -> [q] -> IO () -executeMany db q qs = void $ PSQL.executeMany db q qs +executeMany :: ToRow q => PSQL.Connection -> Query -> [q] -> IO () +executeMany db q qs = void $ PSQL.executeMany db q qs `E.catch` addSql q {-# INLINE executeMany #-} +query :: (ToRow q, FromRow r) => PSQL.Connection -> Query -> q -> IO [r] +query db q qs = PSQL.query db q qs `E.catch` addSql q +{-# INLINE query #-} + +query_ :: FromRow r => Connection -> Query -> IO [r] +query_ db q = PSQL.query_ db q `E.catch` addSql q +{-# INLINE query_ #-} + +addSql :: Query -> SqlError -> IO r +addSql q e@SqlError {sqlErrorHint = hint} = + E.throwIO e {sqlErrorHint = if B.null hint then fromQuery q else hint <> ", " <> fromQuery q} + -- orphan instances -- used in FileSize diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 1ceb675ab..fd86c8e48 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DuplicateRecordFields #-} @@ -219,6 +220,9 @@ pattern SENT msgId = A.SENT msgId Nothing pattern Rcvd :: AgentMsgId -> AEvent 'AEConn pattern Rcvd agentMsgId <- RCVD MsgMeta {integrity = MsgOk} [MsgReceipt {agentMsgId, msgRcptStatus = MROk}] +pattern Rcvd' :: AgentMsgId -> AgentMsgId -> AEvent 'AEConn +pattern Rcvd' aMsgId rcvdMsgId <- RCVD MsgMeta {integrity = MsgOk, recipient = (aMsgId, _)} [MsgReceipt {agentMsgId = rcvdMsgId, msgRcptStatus = MROk}] + pattern INV :: AConnectionRequestUri -> AEvent 'AEConn pattern INV cReq = A.INV cReq Nothing @@ -331,8 +335,8 @@ functionalAPITests ps = do describe "Duplex connection - delivery stress test" $ do describe "one way (50)" $ testMatrix2Stress ps $ runAgentClientStressTestOneWay 50 xdescribe "one way (1000)" $ testMatrix2Stress ps $ runAgentClientStressTestOneWay 1000 - describe "two way concurrently (50)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 25 - xdescribe "two way concurrently (1000)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 500 + describe "two way concurrently (50)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 50 + xdescribe "two way concurrently (1000)" $ testMatrix2Stress ps $ runAgentClientStressTestConc 1000 describe "Establishing duplex connection, different PQ settings" $ do testPQMatrix2 ps $ runAgentClientTestPQ False True describe "Establishing duplex connection v2, different Ratchet versions" $ @@ -782,36 +786,64 @@ runAgentClientStressTestOneWay n pqSupport sqSecured viaProxy alice bob baseId = runAgentClientStressTestConc :: HasCallStack => Int64 -> PQSupport -> SndQueueSecured -> Bool -> AgentClient -> AgentClient -> AgentMsgId -> IO () runAgentClientStressTestConc n pqSupport sqSecured viaProxy alice bob baseId = runRight_ $ do - let pqEnc = PQEncryption $ supportPQ pqSupport (aliceId, bobId) <- makeConnection_ pqSupport sqSecured alice bob - let proxySrv = if viaProxy then Just testSMPServer else Nothing - message i = "message " <> bshow i - loop a bId mIdVar i = do - when (i <= n) $ do - mId <- msgId <$> A.sendMessage a bId pqEnc SMP.noMsgFlags (message i) - liftIO $ mId >= i `shouldBe` True - let getEvent = do - get a >>= \case - ("", c, A.SENT _ srv) -> liftIO $ c == bId && srv == proxySrv `shouldBe` True - ("", c, QCONT) -> do - liftIO $ c == bId `shouldBe` True - getEvent - ("", c, Msg' mId pq msg) -> do - -- tests that mId increases - liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True - liftIO $ c == bId && pq == pqEnc && ("message " `B.isPrefixOf` msg) `shouldBe` True - ackMessage a bId mId Nothing - r -> liftIO $ expectationFailure $ "wrong message: " <> show r - getEvent amId <- newTVarIO 0 bmId <- newTVarIO 0 - concurrently_ - (forM_ ([1 .. n * 2] :: [Int64]) $ loop alice bobId amId) - (forM_ ([1 .. n * 2] :: [Int64]) $ loop bob aliceId bmId) + let n2 = n `div` 2 + mapConcurrently_ id + ( [ send alice bobId [1 .. n2], + send alice bobId [n2 + 1 .. n], + send bob aliceId [1 .. n2], + send bob aliceId [n2 + 1 .. n], + receive alice bobId amId (n, n, n, 2 * n), + receive bob aliceId bmId (n, n, n, 2 * n) + ] :: [ExceptT AgentErrorType IO ()] + ) liftIO $ noMessagesIngoreQCONT alice "nothing else should be delivered to alice" liftIO $ noMessagesIngoreQCONT bob "nothing else should be delivered to bob" where msgId = subtract baseId . fst + pqEnc = PQEncryption $ supportPQ pqSupport + proxySrv = if viaProxy then Just testSMPServer else Nothing + message i = "message " <> bshow i + send :: AgentClient -> ConnId -> [Int64] -> ExceptT AgentErrorType IO () + send a bId = mapM_ $ \i -> void $ A.sendMessage a bId pqEnc SMP.noMsgFlags (message i) + receive :: AgentClient -> ConnId -> TVar AgentMsgId -> (Int64, Int64, Int64, Int64) -> ExceptT AgentErrorType IO () + receive a bId mIdVar acc' = loop acc' >> liftIO drain + where + drain = + timeout 50000 (get a) + >>= mapM_ (\case ("", _, QCONT) -> drain; r -> expectationFailure $ "unexpected: " <> show r) + loop (0, 0, 0, 0) = pure () + loop acc@(!s, !m, !r, !o) = + timeout 3000000 (get a) >>= \case + Nothing -> error $ "timeout " <> show acc + Just evt -> case evt of + ("", c, A.SENT mId srv) -> do + liftIO $ c == bId && srv == proxySrv `shouldBe` True + unless (s > 0) $ error "unexpected SENT" + loop (s - 1, m, r, o) + ("", c, QCONT) -> do + liftIO $ c == bId `shouldBe` True + loop (s, m, r, o) + ("", c, Msg' mId pq msg) -> do + -- tests that mId increases + liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True + liftIO $ c == bId && pq == pqEnc && ("message " `B.isPrefixOf` msg) `shouldBe` True + ackMessageAsync a "123" bId mId (Just "") + unless (m > 0) $ error "unexpected MSG" + loop (s, m - 1, r, o) + ("", c, Rcvd' mId rcvdMsgId) -> do + liftIO $ (mId >) <$> atomically (swapTVar mIdVar mId) `shouldReturn` True + liftIO $ c == bId `shouldBe` True + ackMessageAsync a "123" bId mId Nothing + unless (r > 0) $ error "unexpected RCVD" + loop (s, m, r - 1, o) + ("123", c, OK) -> do + liftIO $ c == bId `shouldBe` True + unless (o > 0) $ error "unexpected OK" + loop (s, m, r, o - 1) + _ -> liftIO $ expectationFailure $ "unexpected: " <> show r testEnablePQEncryption :: HasCallStack => IO () testEnablePQEncryption = @@ -999,10 +1031,10 @@ noMessages_ :: Bool -> HasCallStack => AgentClient -> String -> Expectation noMessages_ ingoreQCONT c err = tryGet `shouldReturn` () where tryGet = - 10000 `timeout` get c >>= \case + 50000 `timeout` get c >>= \case Just (_, _, QCONT) | ingoreQCONT -> noMessages_ ingoreQCONT c err Just msg -> error $ err <> ": " <> show msg - _ -> return () + Nothing -> return () testRejectContactRequest :: HasCallStack => IO () testRejectContactRequest = @@ -3679,7 +3711,15 @@ getSMPAgentClient' clientId cfg' initServers dbPath = do #if defined(dbPostgres) createStore :: String -> IO (Either MigrationError DBStore) -createStore schema = createAgentStore (DBOpts testDBConnstr (B.pack schema) 1 True) (MigrationConfig MCError Nothing) +createStore schema = createAgentStore dbOpts $ MigrationConfig MCError Nothing + where + dbOpts = + DBOpts + { connstr = testDBConnstr, + schema = B.pack schema, + poolSize = 10, + createSchema = True + } insertUser :: DBStore -> IO () insertUser st = withTransaction st (`DB.execute_` "INSERT INTO users DEFAULT VALUES") diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 30b648401..d7b72b766 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -87,7 +87,7 @@ ntfTestStoreDBOpts = DBOpts { connstr = ntfTestServerDBConnstr, schema = "ntf_server", - poolSize = 3, + poolSize = 10, createSchema = True } diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 3c1ac0150..9a88865e3 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -93,7 +93,7 @@ testStoreDBOpts = DBOpts { connstr = testServerDBConnstr, schema = "smp_server", - poolSize = 3, + poolSize = 10, createSchema = True }