diff --git a/Sources/Helpers/EventEmitter.swift b/Sources/Helpers/EventEmitter.swift index 6d360a20..3d189d9f 100644 --- a/Sources/Helpers/EventEmitter.swift +++ b/Sources/Helpers/EventEmitter.swift @@ -8,7 +8,7 @@ import ConcurrencyExtras import Foundation -public final class ObservationToken: Sendable { +public final class ObservationToken: Sendable, Hashable { let _onCancel = LockIsolated((@Sendable () -> Void)?.none) package init(_ onCancel: (@Sendable () -> Void)? = nil) { @@ -34,6 +34,20 @@ public final class ObservationToken: Sendable { deinit { cancel() } + + public static func == (lhs: ObservationToken, rhs: ObservationToken) -> Bool { + ObjectIdentifier(lhs) == ObjectIdentifier(rhs) + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(ObjectIdentifier(self)) + } +} + +extension ObservationToken { + public func store(in set: inout Set) { + set.insert(self) + } } package final class EventEmitter: Sendable { diff --git a/Sources/Realtime/V2/RealtimeChannelV2.swift b/Sources/Realtime/V2/RealtimeChannelV2.swift index 284ff3b4..7ed0b25e 100644 --- a/Sources/Realtime/V2/RealtimeChannelV2.swift +++ b/Sources/Realtime/V2/RealtimeChannelV2.swift @@ -77,6 +77,17 @@ public final class RealtimeChannelV2: Sendable { statusEventEmitter.stream() } + /// Listen for connection status changes. + /// - Parameter listener: Closure that will be called when connection status changes. + /// - Returns: An observation handle that can be used to stop listening. + /// + /// - Note: Use ``statusChange`` if you prefer to use Async/Await. + public func onStatusChange( + _ listener: @escaping @Sendable (Status) -> Void + ) -> ObservationToken { + statusEventEmitter.attach(listener) + } + init( topic: String, config: RealtimeChannelConfig, diff --git a/Tests/IntegrationTests/RealtimeIntegrationTests.swift b/Tests/IntegrationTests/RealtimeIntegrationTests.swift index e768b91c..2e86d860 100644 --- a/Tests/IntegrationTests/RealtimeIntegrationTests.swift +++ b/Tests/IntegrationTests/RealtimeIntegrationTests.swift @@ -37,126 +37,130 @@ final class RealtimeIntegrationTests: XCTestCase { ) func testBroadcast() async throws { - let expectation = expectation(description: "receivedBroadcastMessages") - expectation.expectedFulfillmentCount = 3 + try await withMainSerialExecutor { + let expectation = expectation(description: "receivedBroadcastMessages") + expectation.expectedFulfillmentCount = 3 - let channel = realtime.channel("integration") { - $0.broadcast.receiveOwnBroadcasts = true - } + let channel = realtime.channel("integration") { + $0.broadcast.receiveOwnBroadcasts = true + } - let receivedMessages = LockIsolated<[JSONObject]>([]) + let receivedMessages = LockIsolated<[JSONObject]>([]) - Task { - for await message in channel.broadcastStream(event: "test") { - receivedMessages.withValue { - $0.append(message) + Task { + for await message in channel.broadcastStream(event: "test") { + receivedMessages.withValue { + $0.append(message) + } + expectation.fulfill() } - expectation.fulfill() } - } - await Task.megaYield() + await Task.yield() - await channel.subscribe() + await channel.subscribe() - struct Message: Codable { - var value: Int - } + struct Message: Codable { + var value: Int + } - try await channel.broadcast(event: "test", message: Message(value: 1)) - try await channel.broadcast(event: "test", message: Message(value: 2)) - try await channel.broadcast(event: "test", message: ["value": 3, "another_value": 42]) + try await channel.broadcast(event: "test", message: Message(value: 1)) + try await channel.broadcast(event: "test", message: Message(value: 2)) + try await channel.broadcast(event: "test", message: ["value": 3, "another_value": 42]) - await fulfillment(of: [expectation], timeout: 0.5) + await fulfillment(of: [expectation], timeout: 0.5) - XCTAssertNoDifference( - receivedMessages.value, - [ + XCTAssertNoDifference( + receivedMessages.value, [ - "event": "test", - "payload": [ - "value": 1, + [ + "event": "test", + "payload": [ + "value": 1, + ], + "type": "broadcast", ], - "type": "broadcast", - ], - [ - "event": "test", - "payload": [ - "value": 2, + [ + "event": "test", + "payload": [ + "value": 2, + ], + "type": "broadcast", ], - "type": "broadcast", - ], - [ - "event": "test", - "payload": [ - "value": 3, - "another_value": 42, + [ + "event": "test", + "payload": [ + "value": 3, + "another_value": 42, + ], + "type": "broadcast", ], - "type": "broadcast", - ], - ] - ) + ] + ) - await channel.unsubscribe() + await channel.unsubscribe() + } } func testPresence() async throws { - let channel = realtime.channel("integration") { - $0.broadcast.receiveOwnBroadcasts = true - } + try await withMainSerialExecutor { + let channel = realtime.channel("integration") { + $0.broadcast.receiveOwnBroadcasts = true + } - let expectation = expectation(description: "presenceChange") - expectation.expectedFulfillmentCount = 4 + let expectation = expectation(description: "presenceChange") + expectation.expectedFulfillmentCount = 4 - let receivedPresenceChanges = LockIsolated<[any PresenceAction]>([]) + let receivedPresenceChanges = LockIsolated<[any PresenceAction]>([]) - Task { - for await presence in channel.presenceChange() { - receivedPresenceChanges.withValue { - $0.append(presence) + Task { + for await presence in channel.presenceChange() { + receivedPresenceChanges.withValue { + $0.append(presence) + } + expectation.fulfill() } - expectation.fulfill() } - } - - await Task.megaYield() - await channel.subscribe() + await Task.yield() - struct UserState: Codable, Equatable { - let email: String - } + await channel.subscribe() - try await channel.track(UserState(email: "test@supabase.com")) - try await channel.track(["email": "test2@supabase.com"]) - - await channel.untrack() + struct UserState: Codable, Equatable { + let email: String + } - await fulfillment(of: [expectation], timeout: 0.5) + try await channel.track(UserState(email: "test@supabase.com")) + try await channel.track(["email": "test2@supabase.com"]) - let joins = try receivedPresenceChanges.value.map { try $0.decodeJoins(as: UserState.self) } - let leaves = try receivedPresenceChanges.value.map { try $0.decodeLeaves(as: UserState.self) } - XCTAssertNoDifference( - joins, - [ - [], // This is the first PRESENCE_STATE event. - [UserState(email: "test@supabase.com")], - [UserState(email: "test2@supabase.com")], - [], - ] - ) + await channel.untrack() - XCTAssertNoDifference( - leaves, - [ - [], // This is the first PRESENCE_STATE event. - [], - [UserState(email: "test@supabase.com")], - [UserState(email: "test2@supabase.com")], - ] - ) + await fulfillment(of: [expectation], timeout: 0.5) - await channel.unsubscribe() + let joins = try receivedPresenceChanges.value.map { try $0.decodeJoins(as: UserState.self) } + let leaves = try receivedPresenceChanges.value.map { try $0.decodeLeaves(as: UserState.self) } + XCTAssertNoDifference( + joins, + [ + [], // This is the first PRESENCE_STATE event. + [UserState(email: "test@supabase.com")], + [UserState(email: "test2@supabase.com")], + [], + ] + ) + + XCTAssertNoDifference( + leaves, + [ + [], // This is the first PRESENCE_STATE event. + [], + [UserState(email: "test@supabase.com")], + [UserState(email: "test2@supabase.com")], + ] + ) + + await channel.unsubscribe() + } } // FIXME: Test getting stuck @@ -179,7 +183,7 @@ final class RealtimeIntegrationTests: XCTestCase { // await channel.postgresChange(AnyAction.self, schema: "public").prefix(3).collect() // } // -// await Task.megaYield() +// await Task.yield() // await channel.subscribe() // // struct Entry: Codable, Equatable { diff --git a/Tests/RealtimeTests/MockWebSocketClient.swift b/Tests/RealtimeTests/MockWebSocketClient.swift index a6431f77..7c80d60b 100644 --- a/Tests/RealtimeTests/MockWebSocketClient.swift +++ b/Tests/RealtimeTests/MockWebSocketClient.swift @@ -15,42 +15,81 @@ import XCTestDynamicOverlay #endif final class MockWebSocketClient: WebSocketClient { - let sentMessages = LockIsolated<[RealtimeMessageV2]>([]) + struct MutableState { + var receiveContinuation: AsyncThrowingStream.Continuation? + var sentMessages: [RealtimeMessageV2] = [] + var onCallback: ((RealtimeMessageV2) -> RealtimeMessageV2?)? + var connectContinuation: AsyncStream.Continuation? + + var sendMessageBuffer: [RealtimeMessageV2] = [] + var connectionStatusBuffer: [ConnectionStatus] = [] + } + + private let mutableState = LockIsolated(MutableState()) + + var sentMessages: [RealtimeMessageV2] { + mutableState.sentMessages + } + func send(_ message: RealtimeMessageV2) async throws { - sentMessages.withValue { - $0.append(message) - } + mutableState.withValue { + $0.sentMessages.append(message) - if let callback = onCallback.value, let response = callback(message) { - mockReceive(response) + if let callback = $0.onCallback, let response = callback(message) { + mockReceive(response) + } } } - private let receiveContinuation = - LockIsolated.Continuation?>(nil) func mockReceive(_ message: RealtimeMessageV2) { - receiveContinuation.value?.yield(message) + mutableState.withValue { + if let continuation = $0.receiveContinuation { + continuation.yield(message) + } else { + $0.sendMessageBuffer.append(message) + } + } } - private let onCallback = LockIsolated<((RealtimeMessageV2) -> RealtimeMessageV2?)?>(nil) func on(_ callback: @escaping (RealtimeMessageV2) -> RealtimeMessageV2?) { - onCallback.setValue(callback) + mutableState.withValue { + $0.onCallback = callback + } } func receive() -> AsyncThrowingStream { let (stream, continuation) = AsyncThrowingStream.makeStream() - receiveContinuation.setValue(continuation) + mutableState.withValue { + $0.receiveContinuation = continuation + + while !$0.sendMessageBuffer.isEmpty { + let message = $0.sendMessageBuffer.removeFirst() + $0.receiveContinuation?.yield(message) + } + } return stream } - private let connectContinuation = LockIsolated.Continuation?>(nil) func mockConnect(_ status: ConnectionStatus) { - connectContinuation.value?.yield(status) + mutableState.withValue { + if let continuation = $0.connectContinuation { + continuation.yield(status) + } else { + $0.connectionStatusBuffer.append(status) + } + } } func connect() -> AsyncStream { let (stream, continuation) = AsyncStream.makeStream() - connectContinuation.setValue(continuation) + mutableState.withValue { + $0.connectContinuation = continuation + + while !$0.connectionStatusBuffer.isEmpty { + let status = $0.connectionStatusBuffer.removeFirst() + $0.connectContinuation?.yield(status) + } + } return stream } diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index 5ef22e6a..428545aa 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -42,36 +42,53 @@ final class RealtimeTests: XCTestCase { } func testBehavior() async throws { - try await withTimeout(interval: 2) { [self] in - let channel = sut.channel("public:messages") - _ = channel.postgresChange(InsertAction.self, table: "messages") - _ = channel.postgresChange(UpdateAction.self, table: "messages") - _ = channel.postgresChange(DeleteAction.self, table: "messages") + let channel = sut.channel("public:messages") + var subscriptions: Set = [] - let statusChange = sut.statusChange + channel.onPostgresChange(InsertAction.self, table: "messages") { _ in + } + .store(in: &subscriptions) - await connectSocketAndWait() + channel.onPostgresChange(UpdateAction.self, table: "messages") { _ in + } + .store(in: &subscriptions) - let status = await statusChange.prefix(3).collect() - XCTAssertEqual(status, [.disconnected, .connecting, .connected]) + channel.onPostgresChange(DeleteAction.self, table: "messages") { _ in + } + .store(in: &subscriptions) - let messageTask = sut.mutableState.messageTask - XCTAssertNotNil(messageTask) + let socketStatuses = LockIsolated([RealtimeClientV2.Status]()) - let heartbeatTask = sut.mutableState.heartbeatTask - XCTAssertNotNil(heartbeatTask) + sut.onStatusChange { status in + socketStatuses.withValue { $0.append(status) } + } + .store(in: &subscriptions) - let subscription = Task { - await channel.subscribe() - } - await Task.megaYield() - ws.mockReceive(.messagesSubscribed) + await connectSocketAndWait() - // Wait until channel subscribed - await subscription.value + XCTAssertEqual(socketStatuses.value, [.disconnected, .connecting, .connected]) - XCTAssertNoDifference(ws.sentMessages.value, [.subscribeToMessages(ref: "1", joinRef: "1")]) + let messageTask = sut.mutableState.messageTask + XCTAssertNotNil(messageTask) + + let heartbeatTask = sut.mutableState.heartbeatTask + XCTAssertNotNil(heartbeatTask) + + let channelStatuses = LockIsolated([RealtimeChannelV2.Status]()) + channel.onStatusChange { status in + channelStatuses.withValue { + $0.append(status) + } } + .store(in: &subscriptions) + + ws.mockReceive(.messagesSubscribed) + await channel.subscribe() + + XCTAssertNoDifference( + ws.sentMessages, + [.subscribeToMessages(ref: "1", joinRef: "1")] + ) } func testSubscribeTimeout() async throws { @@ -105,16 +122,11 @@ final class RealtimeTests: XCTestCase { } await connectSocketAndWait() - - Task { - await channel.subscribe() - } - - await Task.megaYield() + await channel.subscribe() try? await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) - let joinSentMessages = ws.sentMessages.value.filter { $0.event == "phx_join" } + let joinSentMessages = ws.sentMessages.filter { $0.event == "phx_join" } let expectedMessages = try [ RealtimeMessageV2( @@ -150,90 +162,81 @@ final class RealtimeTests: XCTestCase { } func testHeartbeat() async throws { - try await withTimeout(interval: 4) { [self] in - let expectation = expectation(description: "heartbeat") - expectation.expectedFulfillmentCount = 2 - - ws.on { message in - if message.event == "heartbeat" { - expectation.fulfill() - return RealtimeMessageV2( - joinRef: message.joinRef, - ref: message.ref, - topic: "phoenix", - event: "phx_reply", - payload: [ - "response": [:], - "status": "ok", - ] - ) - } + let expectation = expectation(description: "heartbeat") + expectation.expectedFulfillmentCount = 2 - return nil + ws.on { message in + if message.event == "heartbeat" { + expectation.fulfill() + return RealtimeMessageV2( + joinRef: message.joinRef, + ref: message.ref, + topic: "phoenix", + event: "phx_reply", + payload: [ + "response": [:], + "status": "ok", + ] + ) } - await connectSocketAndWait() - - await fulfillment(of: [expectation], timeout: 3) + return nil } + + await connectSocketAndWait() + + await fulfillment(of: [expectation], timeout: 3) } func testHeartbeat_whenNoResponse_shouldReconnect() async throws { - try await withTimeout(interval: 6) { [self] in - let sentHeartbeatExpectation = expectation(description: "sentHeartbeat") + let sentHeartbeatExpectation = expectation(description: "sentHeartbeat") - ws.on { - if $0.event == "heartbeat" { - sentHeartbeatExpectation.fulfill() - } - - return nil + ws.on { + if $0.event == "heartbeat" { + sentHeartbeatExpectation.fulfill() } - let statuses = LockIsolated<[RealtimeClientV2.Status]>([]) + return nil + } + + let statuses = LockIsolated<[RealtimeClientV2.Status]>([]) - Task { - for await status in sut.statusChange { - statuses.withValue { - $0.append(status) - } + Task { + for await status in sut.statusChange { + statuses.withValue { + $0.append(status) } } - await Task.megaYield() - await connectSocketAndWait() - - await fulfillment(of: [sentHeartbeatExpectation], timeout: 2) - - let pendingHeartbeatRef = sut.mutableState.pendingHeartbeatRef - XCTAssertNotNil(pendingHeartbeatRef) - - // Wait until next heartbeat - try await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) - - // Wait for reconnect delay - try await Task.sleep(nanoseconds: NSEC_PER_SEC * 1) - - XCTAssertEqual( - statuses.value, - [ - .disconnected, - .connecting, - .connected, - .disconnected, - .connecting, - ] - ) } + await Task.yield() + await connectSocketAndWait() + + await fulfillment(of: [sentHeartbeatExpectation], timeout: 2) + + let pendingHeartbeatRef = sut.mutableState.pendingHeartbeatRef + XCTAssertNotNil(pendingHeartbeatRef) + + // Wait until next heartbeat + try await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + + // Wait for reconnect delay + try await Task.sleep(nanoseconds: NSEC_PER_SEC * 1) + + XCTAssertEqual( + statuses.value, + [ + .disconnected, + .connecting, + .connected, + .disconnected, + .connecting, + ] + ) } private func connectSocketAndWait() async { - let connection = Task { - await sut.connect() - } - await Task.megaYield() - ws.mockConnect(.connected) - await connection.value + await sut.connect() } }