diff --git a/Documentation/guides/basic-concepts/Voice/VoiceModule.cs b/Documentation/guides/basic-concepts/Voice/VoiceModule.cs index 58139a000..321b3468a 100644 --- a/Documentation/guides/basic-concepts/Voice/VoiceModule.cs +++ b/Documentation/guides/basic-concepts/Voice/VoiceModule.cs @@ -145,7 +145,7 @@ public async Task EchoAsync() voiceClient.VoiceReceive += args => { // Pass current user voice directly to the output to create echo - if (voiceClient.Cache.Users.TryGetValue(args.Ssrc, out var voiceUserId) && voiceUserId == userId) + if (voiceClient.Cache.SsrcUsers.TryGetValue(args.Ssrc, out var voiceUserId) && voiceUserId == userId) outStream.Write(args.Frame); return default; }; diff --git a/Hosting/NetCord.Hosting.Services/ComponentInteractions/ComponentInteractionServiceData.cs b/Hosting/NetCord.Hosting.Services/ComponentInteractions/ComponentInteractionServiceData.cs index 2fa53a44a..277a549c6 100644 --- a/Hosting/NetCord.Hosting.Services/ComponentInteractions/ComponentInteractionServiceData.cs +++ b/Hosting/NetCord.Hosting.Services/ComponentInteractions/ComponentInteractionServiceData.cs @@ -1,4 +1,5 @@ using NetCord.Services.ComponentInteractions; namespace NetCord.Hosting.Services.ComponentInteractions; + internal record ComponentInteractionServiceData(IComponentInteractionService Service, IComponentInteractionsBuilder Builder); diff --git a/NetCord/AssemblyInfo.cs b/NetCord/AssemblyInfo.cs new file mode 100644 index 000000000..7450626b2 --- /dev/null +++ b/NetCord/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: DisableRuntimeMarshalling] diff --git a/NetCord/Components/PremiumButton.cs b/NetCord/Components/PremiumButton.cs index 6860a6e9f..2661a72a8 100644 --- a/NetCord/Components/PremiumButton.cs +++ b/NetCord/Components/PremiumButton.cs @@ -1,6 +1,7 @@ using NetCord.JsonModels; namespace NetCord; + public class PremiumButton(JsonComponent jsonModel) : IButton, IJsonModel { JsonComponent IJsonModel.JsonModel => jsonModel; diff --git a/NetCord/Gateway/GatewayClient.cs b/NetCord/Gateway/GatewayClient.cs index 77855002f..c66962020 100644 --- a/NetCord/Gateway/GatewayClient.cs +++ b/NetCord/Gateway/GatewayClient.cs @@ -2,6 +2,7 @@ using NetCord.Gateway.Compression; using NetCord.Gateway.JsonModels; +using NetCord.Gateway.WebSockets; using NetCord.Logging; using WebSocketCloseStatus = System.Net.WebSockets.WebSocketCloseStatus; @@ -904,7 +905,7 @@ private ValueTask SendIdentifyAsync(ConnectionState connectionState, PresencePro Intents = _intents, }).Serialize(Serialization.Default.GatewayPayloadPropertiesGatewayIdentifyProperties); _latencyTimer.Start(); - return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalPayloadProperties, cancellationToken); + return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalTextPayloadProperties, cancellationToken); } /// @@ -949,17 +950,17 @@ private ValueTask TryResumeAsync(ConnectionState connectionState, string session { var serializedPayload = new GatewayPayloadProperties(GatewayOpcode.Resume, new(Token.RawToken, sessionId, sequenceNumber)).Serialize(Serialization.Default.GatewayPayloadPropertiesGatewayResumeProperties); _latencyTimer.Start(); - return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalPayloadProperties, cancellationToken); + return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalTextPayloadProperties, cancellationToken); } private protected override ValueTask HeartbeatAsync(ConnectionState connectionState, CancellationToken cancellationToken = default) { var serializedPayload = new GatewayPayloadProperties(GatewayOpcode.Heartbeat, SequenceNumber).Serialize(Serialization.Default.GatewayPayloadPropertiesInt32); _latencyTimer.Start(); - return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalPayloadProperties, cancellationToken); + return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalTextPayloadProperties, cancellationToken); } - private protected override ValueTask ProcessPayloadAsync(State state, ConnectionState connectionState, ReadOnlySpan payload) + private protected override ValueTask ProcessPayloadAsync(State state, ConnectionState connectionState, WebSocketMessageType messageType, ReadOnlySpan payload) { var jsonPayload = JsonSerializer.Deserialize(_compression.Decompress(payload), Serialization.Default.JsonGatewayPayload)!; return HandlePayloadAsync(state, connectionState, jsonPayload); diff --git a/NetCord/Gateway/Voice/BinaryModels/BinaryVoicePayload.cs b/NetCord/Gateway/Voice/BinaryModels/BinaryVoicePayload.cs new file mode 100644 index 000000000..02518149b --- /dev/null +++ b/NetCord/Gateway/Voice/BinaryModels/BinaryVoicePayload.cs @@ -0,0 +1,14 @@ +using System.Buffers.Binary; + +namespace NetCord.Gateway.Voice.BinaryModels; + +internal readonly ref struct BinaryVoicePayload(ReadOnlySpan payload) +{ + private readonly ReadOnlySpan _payload = payload; + + public ushort SequencyNumber => BinaryPrimitives.ReadUInt16BigEndian(_payload); + + public VoiceOpcode Opcode => (VoiceOpcode)_payload[2]; + + public ReadOnlySpan Data => _payload[3..]; +} diff --git a/NetCord/Gateway/Voice/ConcurrentVoiceClientCache.cs b/NetCord/Gateway/Voice/ConcurrentVoiceClientCache.cs index 3e398c6a5..2d98414a2 100644 --- a/NetCord/Gateway/Voice/ConcurrentVoiceClientCache.cs +++ b/NetCord/Gateway/Voice/ConcurrentVoiceClientCache.cs @@ -1,4 +1,5 @@ -using System.Collections.Concurrent; +using System.Collections; +using System.Collections.Concurrent; using NetCord.Gateway.Voice.JsonModels; @@ -39,34 +40,39 @@ public sealed class ConcurrentVoiceClientCache : IVoiceClientCache { internal ConcurrentVoiceClientCache() { - _ssrcs = new(); - _users = new(); + _users = []; + _userSsrcs = []; + _ssrcUsers = []; } internal ConcurrentVoiceClientCache(JsonVoiceClientCache jsonModel) { _ssrc = jsonModel.Ssrc; - _ssrcs = new(jsonModel.Ssrcs); _users = new(jsonModel.Users); + _userSsrcs = new(jsonModel.UserSsrcs); + _ssrcUsers = new(jsonModel.SsrcUsers); } public uint Ssrc => _ssrc; - public IReadOnlyDictionary Ssrcs => _ssrcs; - public IReadOnlyDictionary Users => _users; + public IReadOnlySet Users => _users; + public IReadOnlyDictionary UserSsrcs => _userSsrcs; + public IReadOnlyDictionary SsrcUsers => _ssrcUsers; #pragma warning disable IDE0032 // Use auto property private uint _ssrc; #pragma warning restore IDE0032 // Use auto property - private readonly ConcurrentDictionary _ssrcs; - private readonly ConcurrentDictionary _users; + private readonly ConcurrentHashSet _users; + private readonly ConcurrentDictionary _userSsrcs; + private readonly ConcurrentDictionary _ssrcUsers; public JsonVoiceClientCache ToJsonModel() { return new() { Ssrc = _ssrc, - Ssrcs = _ssrcs.ToDictionary(), - Users = _users.ToDictionary(), + Users = _users.ToArray(), + UserSsrcs = _userSsrcs.ToArray().ToDictionary(), + SsrcUsers = _ssrcUsers.ToArray().ToDictionary(), }; } @@ -77,18 +83,28 @@ public IVoiceClientCache CacheCurrentSsrc(uint ssrc) return this; } - public IVoiceClientCache CacheUser(ulong userId, uint ssrc) + public IVoiceClientCache CacheUsers(IReadOnlyList userId) { - _ssrcs[userId] = ssrc; - _users[ssrc] = userId; + int count = userId.Count; + for (int i = 0; i < count; i++) + _users.Add(userId[i]); + + return this; + } + + public IVoiceClientCache CacheUserSsrc(ulong userId, uint ssrc) + { + _userSsrcs[userId] = ssrc; + _ssrcUsers[ssrc] = userId; return this; } public IVoiceClientCache RemoveUser(ulong userId) { - if (_ssrcs.TryRemove(userId, out var ssrc)) - _users.TryRemove(ssrc, out _); + _users.Remove(userId); + if (_userSsrcs.TryRemove(userId, out var ssrc)) + _ssrcUsers.TryRemove(ssrc, out _); return this; } @@ -103,4 +119,77 @@ public IReadOnlyDictionary CreateDictionary public void Dispose() { } + + private class ConcurrentHashSet : IReadOnlySet where T : notnull + { + private readonly ConcurrentDictionary _storage; + + public ConcurrentHashSet() + { + _storage = []; + } + + public ConcurrentHashSet(IEnumerable collection) + { + _storage = new(collection.Select(item => new KeyValuePair(item, 0))); + } + + public T[] ToArray() => [.. _storage.Keys]; + + private HashSet HashSet => [.. _storage.Keys]; + + public int Count => _storage.Count; + + public bool Add(T item) + { + return _storage.TryAdd(item, 0); + } + + public bool Remove(T item) + { + return _storage.TryRemove(item, out _); + } + + public bool Contains(T item) + { + return _storage.ContainsKey(item); + } + + public IEnumerator GetEnumerator() + { + return _storage.Select(p => p.Key).GetEnumerator(); + } + + public bool IsProperSubsetOf(IEnumerable other) + { + return HashSet.IsProperSubsetOf(other); + } + + public bool IsProperSupersetOf(IEnumerable other) + { + return HashSet.IsProperSupersetOf(other); + } + + public bool IsSubsetOf(IEnumerable other) + { + return HashSet.IsSubsetOf(other); + } + + public bool IsSupersetOf(IEnumerable other) + { + return HashSet.IsSupersetOf(other); + } + + public bool Overlaps(IEnumerable other) + { + return HashSet.Overlaps(other); + } + + public bool SetEquals(IEnumerable other) + { + return HashSet.SetEquals(other); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } } diff --git a/NetCord/Gateway/Voice/Dave.cs b/NetCord/Gateway/Voice/Dave.cs new file mode 100644 index 000000000..8f06dab41 --- /dev/null +++ b/NetCord/Gateway/Voice/Dave.cs @@ -0,0 +1,285 @@ +using System.Runtime.InteropServices; + +namespace NetCord.Gateway.Voice; + +public class DaveEncryptorException(string? message) : Exception(message) +{ + internal DaveEncryptorException(Dave.EncryptorResultCode result) : this($"Dave encryptor error: {result}") + { + } +} + +public class DaveDecryptorException(string? message) : Exception(message) +{ + internal DaveDecryptorException(Dave.DecryptorResultCode result) : this($"Dave decryptor error: {result}") + { + } +} + +internal static unsafe partial class Dave +{ + private const string DllName = "dave"; + + public const int InitTransitionId = 0; + + public const int DisabledVersion = 0; + + public class SessionHandle() : SafeHandle(0, true) + { + public override bool IsInvalid => handle is 0; + + protected override bool ReleaseHandle() + { + SessionDestroy(handle); + return true; + } + } + + public class CommitResultHandle() : SafeHandle(0, true) + { + public override bool IsInvalid => handle is 0; + + protected override bool ReleaseHandle() + { + CommitResultDestroy(handle); + return true; + } + } + + public class WelcomeResultHandle() : SafeHandle(0, true) + { + public override bool IsInvalid => handle is 0; + + protected override bool ReleaseHandle() + { + WelcomeResultDestroy(handle); + return true; + } + } + + public class KeyRatchetHandle() : SafeHandle(0, true) + { + public override bool IsInvalid => handle is 0; + + protected override bool ReleaseHandle() + { + KeyRatchetDestroy(handle); + return true; + } + } + + public class EncryptorHandle() : SafeHandle(0, true) + { + public override bool IsInvalid => handle is 0; + + protected override bool ReleaseHandle() + { + EncryptorDestroy(handle); + return true; + } + } + + public class DecryptorHandle() : SafeHandle(0, true) + { + public override bool IsInvalid => handle is 0; + + protected override bool ReleaseHandle() + { + DecryptorDestroy(handle); + return true; + } + } + + public enum CodecType + { + Unknown = 0, + Opus = 1, + Vp8 = 2, + Vp9 = 3, + H264 = 4, + H265 = 5, + Av1 = 6, + } + + public enum MediaType + { + Audio = 0, + Video = 1, + } + + public enum EncryptorResultCode + { + Success = 0, + EncryptionFailure = 1, + } + + public enum DecryptorResultCode + { + Success = 0, + DecryptionFailure = 1, + MissingKeyRatchet = 2, + InvalidNonce = 3, + MissingCryptor = 4, + } + + public enum LoggingSeverity + { + Verbose = 0, + Info = 1, + Warning = 2, + Error = 3, + None = 4, + } + +#pragma warning disable CS0649 + public struct EncryptorStats + { + public ulong PassthroughCount; + public ulong EncryptSuccessCount; + public ulong EncryptFailureCount; + public ulong EncryptDuration; + public ulong EncryptAttempts; + public ulong EncryptMaxAttempts; + public ulong EncryptMissingKeyCount; + } + + public struct DecryptorStats + { + public ulong PassthroughCount; + public ulong DecryptSuccessCount; + public ulong DecryptFailureCount; + public ulong DecryptDuration; + public ulong DecryptAttempts; + public ulong DecryptMissingKeyCount; + public ulong DecryptInvalidNonceCount; + } +#pragma warning restore CS0649 + + [LibraryImport(DllName, EntryPoint = "daveMaxSupportedProtocolVersion")] + public static partial ushort MaxSupportedProtocolVersion(); + + [LibraryImport(DllName, EntryPoint = "daveSessionCreate")] + public static partial SessionHandle SessionCreate(void* context, ReadOnlySpan authSessionId, delegate* mlsFailureCallback, void* userData); + + [LibraryImport(DllName, EntryPoint = "daveSessionDestroy")] + public static partial void SessionDestroy(nint session); + + [LibraryImport(DllName, EntryPoint = "daveSessionInit")] + public static partial void SessionInit(SessionHandle session, ushort version, ulong groupId, ReadOnlySpan selfUserId); + + [LibraryImport(DllName, EntryPoint = "daveSessionReset")] + public static partial void SessionReset(SessionHandle session); + + [LibraryImport(DllName, EntryPoint = "daveSessionSetProtocolVersion")] + public static partial void SessionSetProtocolVersion(SessionHandle session, ushort version); + + [LibraryImport(DllName, EntryPoint = "daveSessionGetProtocolVersion")] + public static partial ushort SessionGetProtocolVersion(SessionHandle session); + + [LibraryImport(DllName, EntryPoint = "daveSessionGetLastEpochAuthenticator")] + public static partial void SessionGetLastEpochAuthenticator(SessionHandle session, byte** authenticator, nuint* length); + + [LibraryImport(DllName, EntryPoint = "daveSessionSetExternalSender")] + public static partial void SessionSetExternalSender(SessionHandle session, byte* externalSender, nuint length); + + [LibraryImport(DllName, EntryPoint = "daveSessionProcessProposals")] + public static partial void SessionProcessProposals(SessionHandle session, byte* proposals, nuint length, ReadOnlySpan recognizedUserIds, nuint recognizedUserIdsLength, out byte* commitWelcomeBytes, out nuint commitWelcomeBytesLength); + + [LibraryImport(DllName, EntryPoint = "daveSessionProcessCommit")] + public static partial CommitResultHandle SessionProcessCommit(SessionHandle session, ReadOnlySpan commit, nuint length); + + [LibraryImport(DllName, EntryPoint = "daveSessionProcessWelcome")] + public static partial WelcomeResultHandle SessionProcessWelcome(SessionHandle session, ReadOnlySpan welcome, nuint length, ReadOnlySpan recognizedUserIds, nuint recognizedUserIdsLength); + + [LibraryImport(DllName, EntryPoint = "daveSessionGetMarshalledKeyPackage")] + public static partial void SessionGetMarshalledKeyPackage(SessionHandle session, out byte* keyPackage, out nuint length); + + [LibraryImport(DllName, EntryPoint = "daveSessionGetKeyRatchet", StringMarshalling = StringMarshalling.Utf8)] + public static partial KeyRatchetHandle SessionGetKeyRatchet(SessionHandle session, ReadOnlySpan userId); + + [LibraryImport(DllName, EntryPoint = "daveSessionGetPairwiseFingerprint", StringMarshalling = StringMarshalling.Utf8)] + public static partial void SessionGetPairwiseFingerprint(SessionHandle session, ushort version, ReadOnlySpan userId, delegate* pairwiseFingerprintCallback, void* userData); + + [LibraryImport(DllName, EntryPoint = "daveKeyRatchetDestroy")] + public static partial void KeyRatchetDestroy(nint keyRatchet); + + [LibraryImport(DllName, EntryPoint = "daveCommitResultIsFailed")] + [return: MarshalAs(UnmanagedType.U1)] + public static partial bool CommitResultIsFailed(CommitResultHandle commitResultHandle); + + [LibraryImport(DllName, EntryPoint = "daveCommitResultIsIgnored")] + [return: MarshalAs(UnmanagedType.U1)] + public static partial bool CommitResultIsIgnored(CommitResultHandle commitResultHandle); + + [LibraryImport(DllName, EntryPoint = "daveCommitResultGetRosterMemberIds")] + public static partial void CommitResultGetRosterMemberIds(CommitResultHandle commitResultHandle, ulong** rosterIds, nuint* rosterIdsLength); + + [LibraryImport(DllName, EntryPoint = "daveCommitResultGetRosterMemberSignature")] + public static partial void CommitResultGetRosterMemberSignature(CommitResultHandle commitResultHandle, ulong rosterId, byte** signature, nuint* signatureLength); + + [LibraryImport(DllName, EntryPoint = "daveCommitResultDestroy")] + public static partial void CommitResultDestroy(nint commitResultHandle); + + [LibraryImport(DllName, EntryPoint = "daveWelcomeResultGetRosterMemberIds")] + public static partial void WelcomeResultGetRosterMemberIds(WelcomeResultHandle welcomeResultHandle, ulong** rosterIds, nuint* rosterIdsLength); + + [LibraryImport(DllName, EntryPoint = "daveWelcomeResultGetRosterMemberSignature")] + public static partial void WelcomeResultGetRosterMemberSignature(WelcomeResultHandle welcomeResultHandle, ulong rosterId, byte** signature, nuint* signatureLength); + + [LibraryImport(DllName, EntryPoint = "daveWelcomeResultDestroy")] + public static partial void WelcomeResultDestroy(nint welcomeResultHandle); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorCreate")] + public static partial EncryptorHandle EncryptorCreate(); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorDestroy")] + public static partial void EncryptorDestroy(nint encryptor); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorSetKeyRatchet")] + public static partial void EncryptorSetKeyRatchet(EncryptorHandle encryptor, KeyRatchetHandle keyRatchet); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorSetPassthroughMode")] + public static partial void EncryptorSetPassthroughMode(EncryptorHandle encryptor, [MarshalAs(UnmanagedType.U1)] bool passthroughMode); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorAssignSsrcToCodec")] + public static partial void EncryptorAssignSsrcToCodec(EncryptorHandle encryptor, uint ssrc, CodecType codecType); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorGetProtocolVersion")] + public static partial ushort EncryptorGetProtocolVersion(EncryptorHandle encryptor); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorGetMaxCiphertextByteSize")] + public static partial nuint EncryptorGetMaxCiphertextByteSize(EncryptorHandle encryptor, MediaType mediaType, nuint frameSize); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorEncrypt")] + public static partial EncryptorResultCode EncryptorEncrypt(EncryptorHandle encryptor, MediaType mediaType, uint ssrc, byte* frame, nuint frameLength, byte* encryptedFrame, nuint encryptedFrameCapacity, out nuint bytesWritten); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorSetProtocolVersionChangedCallback")] + public static partial void EncryptorSetProtocolVersionChangedCallback(EncryptorHandle encryptor, delegate* encryptorProtocolVersionChangedCallback, void* userData); + + [LibraryImport(DllName, EntryPoint = "daveEncryptorGetStats")] + public static partial void EncryptorGetStats(EncryptorHandle encryptor, MediaType mediaType, out EncryptorStats stats); + + [LibraryImport(DllName, EntryPoint = "daveDecryptorCreate")] + public static partial DecryptorHandle DecryptorCreate(); + + [LibraryImport(DllName, EntryPoint = "daveDecryptorDestroy")] + public static partial void DecryptorDestroy(nint decryptor); + + [LibraryImport(DllName, EntryPoint = "daveDecryptorTransitionToKeyRatchet")] + public static partial void DecryptorTransitionToKeyRatchet(DecryptorHandle decryptor, KeyRatchetHandle keyRatchet); + + [LibraryImport(DllName, EntryPoint = "daveDecryptorTransitionToPassthroughMode")] + public static partial void DecryptorTransitionToPassthroughMode(DecryptorHandle decryptor, [MarshalAs(UnmanagedType.U1)] bool passthroughMode); + + [LibraryImport(DllName, EntryPoint = "daveDecryptorDecrypt")] + public static partial DecryptorResultCode DecryptorDecrypt(DecryptorHandle decryptor, MediaType mediaType, byte* encryptedFrame, nuint encryptedFrameLength, byte* frame, nuint frameCapacity, out nuint bytesWritten); + + [LibraryImport(DllName, EntryPoint = "daveDecryptorGetMaxPlaintextByteSize")] + public static partial nuint DecryptorGetMaxPlaintextByteSize(DecryptorHandle decryptor, MediaType mediaType, nuint encryptedFrameSize); + + [LibraryImport(DllName, EntryPoint = "daveDecryptorGetStats")] + public static partial void DecryptorGetStats(DecryptorHandle decryptor, MediaType mediaType, DecryptorStats* stats); + + [LibraryImport(DllName, EntryPoint = "daveSetLogSinkCallback")] + public static partial void SetLogSinkCallback(delegate* callback); +} diff --git a/NetCord/Gateway/Voice/DaveMlsInvalidCommitWelcomeProperties.cs b/NetCord/Gateway/Voice/DaveMlsInvalidCommitWelcomeProperties.cs new file mode 100644 index 000000000..d206d73b2 --- /dev/null +++ b/NetCord/Gateway/Voice/DaveMlsInvalidCommitWelcomeProperties.cs @@ -0,0 +1,9 @@ +using System.Text.Json.Serialization; + +namespace NetCord.Gateway.Voice; + +internal class DaveMlsInvalidCommitWelcomeProperties(int transitionId) +{ + [JsonPropertyName("transition_id")] + public int TransitionId { get; set; } = transitionId; +} diff --git a/NetCord/Gateway/Voice/DaveTransitionReadyProperties.cs b/NetCord/Gateway/Voice/DaveTransitionReadyProperties.cs new file mode 100644 index 000000000..a8d53acfe --- /dev/null +++ b/NetCord/Gateway/Voice/DaveTransitionReadyProperties.cs @@ -0,0 +1,9 @@ +using System.Text.Json.Serialization; + +namespace NetCord.Gateway.Voice; + +internal class DaveTransitionReadyProperties(ushort transitionId) +{ + [JsonPropertyName("transition_id")] + public ushort TransitionId { get; set; } = transitionId; +} diff --git a/NetCord/Gateway/Voice/GatewayClientExtensions.cs b/NetCord/Gateway/Voice/GatewayClientExtensions.cs index e03bf5215..111ba4750 100644 --- a/NetCord/Gateway/Voice/GatewayClientExtensions.cs +++ b/NetCord/Gateway/Voice/GatewayClientExtensions.cs @@ -46,7 +46,7 @@ public static async Task JoinVoiceChannelAsync(this GatewayClient c (state, server) = await WaitForEventsAsync(tokenSource.Token).ConfigureAwait(false); } - return new(userId, state.SessionId, server.Endpoint!, guildId, server.Token, configuration); + return new(userId, state.SessionId, server.Endpoint!, guildId, channelId, server.Token, configuration); ValueTask HandleVoiceStateUpdateAsync(VoiceState arg) { diff --git a/NetCord/Gateway/Voice/IVoiceClientCache.cs b/NetCord/Gateway/Voice/IVoiceClientCache.cs index 7c356ec32..35ebe4827 100644 --- a/NetCord/Gateway/Voice/IVoiceClientCache.cs +++ b/NetCord/Gateway/Voice/IVoiceClientCache.cs @@ -3,11 +3,13 @@ public interface IVoiceClientCache : IDictionaryProvider, IDisposable { public uint Ssrc { get; } - public IReadOnlyDictionary Ssrcs { get; } - public IReadOnlyDictionary Users { get; } + public IReadOnlySet Users { get; } + public IReadOnlyDictionary UserSsrcs { get; } + public IReadOnlyDictionary SsrcUsers { get; } public IVoiceClientCache CacheCurrentSsrc(uint ssrc); - public IVoiceClientCache CacheUser(ulong userId, uint ssrc); + public IVoiceClientCache CacheUsers(IReadOnlyList userId); + public IVoiceClientCache CacheUserSsrc(ulong userId, uint ssrc); public IVoiceClientCache RemoveUser(ulong userId); } diff --git a/NetCord/Gateway/Voice/ImmutableVoiceClientCache.cs b/NetCord/Gateway/Voice/ImmutableVoiceClientCache.cs index 0a0794a54..3ea029f80 100644 --- a/NetCord/Gateway/Voice/ImmutableVoiceClientCache.cs +++ b/NetCord/Gateway/Voice/ImmutableVoiceClientCache.cs @@ -46,73 +46,93 @@ internal static ImmutableVoiceClientCache FromJson(JsonVoiceClientCache jsonMode private ImmutableVoiceClientCache() { - _ssrcs = ImmutableDictionary.Empty; - _users = ImmutableDictionary.Empty; + _users = ImmutableHashSet.Empty; + _userSsrcs = ImmutableDictionary.Empty; + _ssrcUsers = ImmutableDictionary.Empty; } private ImmutableVoiceClientCache(JsonVoiceClientCache jsonModel) { _ssrc = jsonModel.Ssrc; - _ssrcs = jsonModel.Ssrcs.ToImmutableDictionary(); - _users = jsonModel.Users.ToImmutableDictionary(); + _users = [.. jsonModel.Users]; + _userSsrcs = jsonModel.UserSsrcs.ToImmutableDictionary(); + _ssrcUsers = jsonModel.SsrcUsers.ToImmutableDictionary(); } - private ImmutableVoiceClientCache(uint ssrc, ImmutableDictionary ssrcs, ImmutableDictionary users) + private ImmutableVoiceClientCache(uint ssrc, ImmutableHashSet users, ImmutableDictionary userSsrcs, ImmutableDictionary ssrcUsers) { _ssrc = ssrc; - _ssrcs = ssrcs; _users = users; + _userSsrcs = userSsrcs; + _ssrcUsers = ssrcUsers; } - private static ImmutableVoiceClientCache Create(uint ssrc, ImmutableDictionary ssrcs, ImmutableDictionary users) + private static ImmutableVoiceClientCache Create(uint ssrc, ImmutableHashSet users, ImmutableDictionary userSsrcs, ImmutableDictionary ssrcUsers) { - return new(ssrc, ssrcs, users); + return new(ssrc, users, userSsrcs, ssrcUsers); } public uint Ssrc => _ssrc; - public IReadOnlyDictionary Ssrcs => _ssrcs; - public IReadOnlyDictionary Users => _users; + public IReadOnlySet Users => _users; + public IReadOnlyDictionary UserSsrcs => _userSsrcs; + public IReadOnlyDictionary SsrcUsers => _ssrcUsers; #pragma warning disable IDE0032 // Use auto property private readonly uint _ssrc; #pragma warning restore IDE0032 // Use auto property - private readonly ImmutableDictionary _ssrcs; - private readonly ImmutableDictionary _users; + private readonly ImmutableHashSet _users; + private readonly ImmutableDictionary _userSsrcs; + private readonly ImmutableDictionary _ssrcUsers; public JsonVoiceClientCache ToJsonModel() { return new() { Ssrc = _ssrc, - Ssrcs = _ssrcs, - Users = _users, + Users = _users.ToArray(), + UserSsrcs = _userSsrcs, + SsrcUsers = _ssrcUsers, }; } public IVoiceClientCache CacheCurrentSsrc(uint ssrc) { return Create(ssrc, - _ssrcs, - _users); + _users, + _userSsrcs, + _ssrcUsers); } - public IVoiceClientCache CacheUser(ulong userId, uint ssrc) + public IVoiceClientCache CacheUsers(IReadOnlyList userId) { return Create(_ssrc, - _ssrcs.SetItem(userId, ssrc), - _users.SetItem(ssrc, userId)); + _users.Union(userId), + _userSsrcs, + _ssrcUsers); + } + + public IVoiceClientCache CacheUserSsrc(ulong userId, uint ssrc) + { + return Create(_ssrc, + _users, + _userSsrcs.SetItem(userId, ssrc), + _ssrcUsers.SetItem(ssrc, userId)); } public IVoiceClientCache RemoveUser(ulong userId) { - var ssrcs = _ssrcs; + var userSsrcs = _userSsrcs; - if (!ssrcs.TryGetValue(userId, out var ssrc)) - return this; + if (!userSsrcs.TryGetValue(userId, out var ssrc)) + return Create(_ssrc, + _users.Remove(userId), + userSsrcs, + _ssrcUsers); return Create(_ssrc, - ssrcs.Remove(userId), - _users.Remove(ssrc)); + _users.Remove(userId), + userSsrcs.Remove(userId), + _ssrcUsers.Remove(ssrc)); } public IReadOnlyDictionary CreateDictionary(IEnumerable source, Func keySelector, Func elementSelector) diff --git a/NetCord/Gateway/Voice/JsonModels/JsonDaveExecuteTransition.cs b/NetCord/Gateway/Voice/JsonModels/JsonDaveExecuteTransition.cs new file mode 100644 index 000000000..163c19d6c --- /dev/null +++ b/NetCord/Gateway/Voice/JsonModels/JsonDaveExecuteTransition.cs @@ -0,0 +1,9 @@ +using System.Text.Json.Serialization; + +namespace NetCord.Gateway.Voice.JsonModels; + +internal class JsonDaveExecuteTransition +{ + [JsonPropertyName("transition_id")] + public ushort TransitionId { get; set; } +} diff --git a/NetCord/Gateway/Voice/JsonModels/JsonDavePrepareEpoch.cs b/NetCord/Gateway/Voice/JsonModels/JsonDavePrepareEpoch.cs new file mode 100644 index 000000000..9007027ac --- /dev/null +++ b/NetCord/Gateway/Voice/JsonModels/JsonDavePrepareEpoch.cs @@ -0,0 +1,12 @@ +using System.Text.Json.Serialization; + +namespace NetCord.Gateway.Voice.JsonModels; + +internal class JsonDavePrepareEpoch +{ + [JsonPropertyName("protocol_version")] + public ushort ProtocolVersion { get; set; } + + [JsonPropertyName("epoch")] + public int Epoch { get; set; } +} diff --git a/NetCord/Gateway/Voice/JsonModels/JsonDavePrepareTransition.cs b/NetCord/Gateway/Voice/JsonModels/JsonDavePrepareTransition.cs new file mode 100644 index 000000000..5e0f7fa63 --- /dev/null +++ b/NetCord/Gateway/Voice/JsonModels/JsonDavePrepareTransition.cs @@ -0,0 +1,12 @@ +using System.Text.Json.Serialization; + +namespace NetCord.Gateway.Voice.JsonModels; + +internal class JsonDavePrepareTransition +{ + [JsonPropertyName("protocol_version")] + public ushort ProtocolVersion { get; set; } + + [JsonPropertyName("transition_id")] + public ushort TransitionId { get; set; } +} diff --git a/NetCord/Gateway/Voice/JsonModels/JsonSessionDescription.cs b/NetCord/Gateway/Voice/JsonModels/JsonSessionDescription.cs index e595027cd..2b49c4863 100644 --- a/NetCord/Gateway/Voice/JsonModels/JsonSessionDescription.cs +++ b/NetCord/Gateway/Voice/JsonModels/JsonSessionDescription.cs @@ -9,6 +9,9 @@ internal class JsonSessionDescription [JsonPropertyName("secret_key")] public byte[] SecretKey { get; set; } + [JsonPropertyName("dave_protocol_version")] + public ushort DaveProtocolVersion { get; set; } + public class ByteArrayOfLength32Converter : JsonConverter { public override byte[]? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) diff --git a/NetCord/Gateway/Voice/JsonModels/JsonVoiceClientCache.cs b/NetCord/Gateway/Voice/JsonModels/JsonVoiceClientCache.cs index 0d47b725d..dcd75850e 100644 --- a/NetCord/Gateway/Voice/JsonModels/JsonVoiceClientCache.cs +++ b/NetCord/Gateway/Voice/JsonModels/JsonVoiceClientCache.cs @@ -7,9 +7,12 @@ public class JsonVoiceClientCache [JsonPropertyName("ssrc")] public uint Ssrc { get; set; } - [JsonPropertyName("ssrcs")] - public IReadOnlyDictionary Ssrcs { get; set; } - [JsonPropertyName("users")] - public IReadOnlyDictionary Users { get; set; } + public IReadOnlyList Users { get; set; } + + [JsonPropertyName("user_ssrcs")] + public IReadOnlyDictionary UserSsrcs { get; set; } + + [JsonPropertyName("ssrc_users")] + public IReadOnlyDictionary SsrcUsers { get; set; } } diff --git a/NetCord/Gateway/Voice/Streams/VoiceOutStream.cs b/NetCord/Gateway/Voice/Streams/VoiceOutStream.cs index e79297ca9..861615ce2 100644 --- a/NetCord/Gateway/Voice/Streams/VoiceOutStream.cs +++ b/NetCord/Gateway/Voice/Streams/VoiceOutStream.cs @@ -48,40 +48,82 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati public override void Write(ReadOnlySpan buffer) { - if (client._udpState is not { Connection: var connection, Encryption: var encryption }) + if (client._udpState is not { Connection: var connection, Encryption: var encryption, DaveSession: var session }) { ThrowConnectionNotStarted(); return; } + byte[]? daveArray = null; + if (session.GetProtocolVersion() is not 0) + { + var daveEncryptor = session.GetEncryptor(); + var length = daveEncryptor.GetMaxCiphertextByteSize(buffer.Length); + daveArray = ArrayPool.Shared.Rent(length); + + int written = daveEncryptor.Encrypt(client.Cache.Ssrc, buffer, daveArray); + + if (written is -1) + { + ArrayPool.Shared.Return(daveArray); + return; + } + + buffer = daveArray.AsSpan(0, written); + } + int datagramLength = buffer.Length + encryption.Expansion + 12; - var array = ArrayPool.Shared.Rent(datagramLength); + var datagramArray = ArrayPool.Shared.Rent(datagramLength); + + WriteDatagram(buffer, new(datagramArray, 0, datagramLength), encryption); - WriteDatagram(buffer, new(array, 0, datagramLength), encryption); + if (daveArray is not null) + ArrayPool.Shared.Return(daveArray); - connection.Send(new(array, 0, datagramLength)); + connection.Send(new(datagramArray, 0, datagramLength)); - ArrayPool.Shared.Return(array); + ArrayPool.Shared.Return(datagramArray); } public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { - if (client._udpState is not { Connection: var connection, Encryption: var encryption }) + if (client._udpState is not { Connection: var connection, Encryption: var encryption, DaveSession: var session }) { ThrowConnectionNotStarted(); return; } + byte[]? daveArray = null; + if (session.GetProtocolVersion() is not 0) + { + var daveEncryptor = session.GetEncryptor(); + var length = daveEncryptor.GetMaxCiphertextByteSize(buffer.Length); + daveArray = ArrayPool.Shared.Rent(length); + + int written = daveEncryptor.Encrypt(client.Cache.Ssrc, buffer.Span, daveArray); + + if (written is -1) + { + ArrayPool.Shared.Return(daveArray); + return; + } + + buffer = daveArray.AsMemory(0, written); + } + int datagramLength = buffer.Length + encryption.Expansion + 12; - var array = ArrayPool.Shared.Rent(datagramLength); + var datagramArray = ArrayPool.Shared.Rent(datagramLength); + + WriteDatagram(buffer.Span, new(datagramArray, 0, datagramLength), encryption); - WriteDatagram(buffer.Span, new(array, 0, datagramLength), encryption); + if (daveArray is not null) + ArrayPool.Shared.Return(daveArray); - await connection.SendAsync(new(array, 0, datagramLength), cancellationToken).ConfigureAwait(false); + await connection.SendAsync(new(datagramArray, 0, datagramLength), cancellationToken).ConfigureAwait(false); - ArrayPool.Shared.Return(array); + ArrayPool.Shared.Return(datagramArray); } private void WriteDatagram(ReadOnlySpan buffer, Span datagram, IVoiceEncryption encryption) diff --git a/NetCord/Gateway/Voice/VoiceClient.DaveSession.cs b/NetCord/Gateway/Voice/VoiceClient.DaveSession.cs new file mode 100644 index 000000000..5f3ed32b6 --- /dev/null +++ b/NetCord/Gateway/Voice/VoiceClient.DaveSession.cs @@ -0,0 +1,422 @@ +using System.Buffers; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +using static NetCord.Gateway.Voice.Dave; + +namespace NetCord.Gateway.Voice; + +#pragma warning disable IDE0290 // Use primary constructor + +public partial class VoiceClient +{ + internal readonly ref struct DaveEncryptor(EncryptorHandle encryptor) + { + public unsafe int Encrypt(uint ssrc, ReadOnlySpan frame, Span encryptedFrame) + { + EncryptorResultCode result; + nuint bytesWritten; + + fixed (byte* frameP = frame, encryptedFrameP = encryptedFrame) + result = EncryptorEncrypt(encryptor, MediaType.Audio, ssrc, frameP, (uint)frame.Length, encryptedFrameP, (nuint)encryptedFrame.Length, out bytesWritten); + + return result is EncryptorResultCode.Success ? (int)bytesWritten : -1; + } + + public readonly int GetMaxCiphertextByteSize(int plaintextByteSize) + => (int)EncryptorGetMaxCiphertextByteSize(encryptor, MediaType.Audio, (nuint)plaintextByteSize); + } + + internal struct DaveDecryptor(DecryptorHandle decryptor) + { + public readonly unsafe int Decrypt(uint ssrc, ReadOnlySpan encryptedFrame, Span frame) + { + DecryptorResultCode result; + nuint bytesWritten; + + fixed (byte* encryptedFrameP = encryptedFrame, frameP = frame) + result = DecryptorDecrypt(decryptor, MediaType.Audio, encryptedFrameP, (nuint)encryptedFrame.Length, frameP, (nuint)frame.Length, out bytesWritten); + + return result is DecryptorResultCode.Success ? (int)bytesWritten : -1; + } + + public readonly int GetMaxPlaintextByteSize(int ciphertextByteSize) + => (int)DecryptorGetMaxPlaintextByteSize(decryptor, MediaType.Audio, (nuint)ciphertextByteSize); + } + + internal class DaveSession : IDisposable + { + private const int MaxSnowflakeCStringSize = 21; + + private const int MlsNewGroupExpectedEpoch = 1; + + private ushort _latestPreparedTransitionVersion; + + private readonly Dictionary _transitions; + + private readonly EncryptorHandle _encryptor; + private readonly ConcurrentDictionary _decryptors; + + private readonly VoiceClient _client; + private readonly SessionHandle _session; + + public unsafe DaveSession(VoiceClient client, delegate* mlsFailureCallback, void* userData) + { + _transitions = []; + + _encryptor = EncryptorCreate(); + _decryptors = []; + + _client = client; + _session = SessionCreate(null, null, mlsFailureCallback, userData); + } + + static unsafe DaveSession() + { + SetLogSinkCallback(&LogSink); + } + + private static unsafe void LogSink(LoggingSeverity severity, byte* file, int line, byte* message) + { + LogSinkInternal(severity, file, line, message); + } + + [Conditional("DEBUG")] + private static unsafe void LogSinkInternal(LoggingSeverity severity, byte* file, int line, byte* message) + { + var fileString = Marshal.PtrToStringUTF8((nint)file); + var messageString = Marshal.PtrToStringUTF8((nint)message); + Debug.WriteLine($"Dave at {fileString}:{line}: {messageString}"); + } + + public static ushort GetMaxSupportedProtocolVersion() + { + return MaxSupportedProtocolVersion(); + } + + public ushort GetProtocolVersion() + { + return SessionGetProtocolVersion(_session); + } + + public DaveEncryptor GetEncryptor() + { + return new(_encryptor); + } + + public DaveDecryptor? GetDecryptor(uint ssrc) + { + return _decryptors.TryGetValue(ssrc, out var decryptor) ? new(decryptor) : null; + } + + public void OnSpeaking(ulong userId, uint ssrc) + { + SetupKeyRatchetForUser(userId, ssrc, _latestPreparedTransitionVersion); + } + + public void OnClientDisconnect(ulong userId) + { + if (_client.Cache.UserSsrcs.TryGetValue(userId, out var ssrc) + && _decryptors.TryRemove(ssrc, out var decryptor)) + decryptor.Dispose(); + } + + public ValueTask OnSessionDescriptionAsync(ConnectionState connectionState, ushort protocolVersion) + { + return HandleInitAsync(connectionState, protocolVersion); + } + + public ValueTask OnPrepareTransitionAsync(ConnectionState connectionState, ushort transitionId, ushort protocolVersion) + { + PrepareRatchets(transitionId, protocolVersion); + if (transitionId is not InitTransitionId) + { + SetDecryptorsPassthroughMode(protocolVersion is DisabledVersion); + return SendTransitionReadyAsync(connectionState, transitionId); + } + + return default; + } + + public void OnExecuteTransition(ushort transitionId) + { + HandleExecuteTransition(transitionId); + } + + public ValueTask OnPrepareEpoch(ConnectionState connectionState, int epoch, ushort protocolVersion) + { + if (epoch is MlsNewGroupExpectedEpoch) + { + InitSession(protocolVersion); + return SendMlsKeyPackageAsync(connectionState); + } + + return default; + } + + public unsafe void OnMlsExternalSender(ReadOnlySpan externalSender) + { + fixed (byte* p = externalSender) + SessionSetExternalSender(_session, p, (nuint)externalSender.Length); + } + + public unsafe ValueTask OnMlsProposalsAsync(ConnectionState connectionState, ReadOnlySpan proposals) + { + var recognizedUserIds = GetRecognizedUserIds(out var buffer); + + byte* commitWelcome; + nuint commitWelcomeLength; + + fixed (byte* proposalsP = proposals) + SessionProcessProposals(_session, proposalsP, (nuint)proposals.Length, recognizedUserIds, (nuint)recognizedUserIds.Length, out commitWelcome, out commitWelcomeLength); + + FreeRecognizedUserIdsBuffer(buffer); + + if (commitWelcome is not null) + return SendMlsCommitWelcomeAsync(connectionState, new(commitWelcome, (int)commitWelcomeLength)); + + return default; + } + + public ValueTask OnMlsPrepareCommitTransitionAsync(ConnectionState connectionState, ushort transitionId, ReadOnlySpan commit) + { + using var commitResultHandle = SessionProcessCommit(_session, commit, (nuint)commit.Length); + + if (CommitResultIsIgnored(commitResultHandle)) + return default; + + return HandleRosterUpdatedAsync(connectionState, transitionId, CommitResultIsFailed(commitResultHandle)); + } + + public ValueTask OnMlsWelcomeAsync(ConnectionState connectionState, ushort transitionId, ReadOnlySpan welcome) + { + var recognizedUserIds = GetRecognizedUserIds(out var buffer); + + using var welcomeResult = SessionProcessWelcome(_session, welcome, (nuint)welcome.Length, recognizedUserIds, (nuint)recognizedUserIds.Length); + + FreeRecognizedUserIdsBuffer(buffer); + + return HandleRosterUpdatedAsync(connectionState, transitionId, welcomeResult.IsInvalid); + } + + private void SetDecryptorsPassthroughMode(bool passthroughMode) + { + foreach (var decryptor in _decryptors.Values) + DecryptorTransitionToPassthroughMode(decryptor, passthroughMode); + } + + private async ValueTask HandleRosterUpdatedAsync(ConnectionState connectionState, ushort transitionId, bool isFailed) + { + var joinedGroup = !isFailed; + if (joinedGroup) + { + PrepareRatchets(transitionId, GetProtocolVersion()); + if (transitionId is not InitTransitionId) + await SendTransitionReadyAsync(connectionState, transitionId).ConfigureAwait(false); + } + else + { + await SendMlsInvalidCommitWelcomeAsync(connectionState, transitionId).ConfigureAwait(false); + await SendMlsKeyPackageAsync(connectionState).ConfigureAwait(false); + } + } + + private unsafe ReadOnlySpan GetRecognizedUserIds(out nint[] pointers) + { + var users = _client.Cache.Users; + + int count = users.Count + 1; + + pointers = ArrayPool.Shared.Rent(count); + + var buffer = (byte*)NativeMemory.Alloc((nuint)(count * MaxSnowflakeCStringSize)); + + var result = pointers.AsSpan(0, count); + + int i = 0; + int written = 0; + foreach (var userId in users) + AddUserId(ref result[i++], ref written, buffer, userId); + + AddUserId(ref result[i], ref written, buffer, _client.UserId); + + return result; + + static void AddUserId(ref nint pointer, ref int written, byte* buffer, ulong userId) + { + var start = buffer + written; + var writtenSpan = SnowflakeToCString(userId, new(start, MaxSnowflakeCStringSize)); + pointer = (nint)start; + written += writtenSpan.Length; + } + } + + private static unsafe void FreeRecognizedUserIdsBuffer(nint[] recognizedUserIds) + { + var buffer = (byte*)recognizedUserIds[0]; + NativeMemory.Free(buffer); + ArrayPool.Shared.Return(recognizedUserIds); + } + + private ValueTask HandleInitAsync(ConnectionState connectionState, ushort protocolVersion) + { + InitSession(protocolVersion); + + if (protocolVersion > DisabledVersion) + return SendMlsKeyPackageAsync(connectionState); + else + { + PrepareRatchets(InitTransitionId, protocolVersion); + HandleExecuteTransition(InitTransitionId); + } + + return default; + } + + private void HandleExecuteTransition(ushort transitionId) + { + if (!_transitions.Remove(transitionId, out var protocolVersion)) + return; + + if (protocolVersion is DisabledVersion) + SessionReset(_session); + + SetupEncryptionKeyRatchet(protocolVersion); + } + + private void PrepareRatchets(ushort transitionId, ushort protocolVersion) + { + foreach (var pair in _client.Cache.UserSsrcs) + SetupKeyRatchetForUser(pair.Key, pair.Value, protocolVersion); + + if (transitionId is InitTransitionId) + SetupEncryptionKeyRatchet(protocolVersion); + else + _transitions[transitionId] = protocolVersion; + + _latestPreparedTransitionVersion = protocolVersion; + } + + private void SetupKeyRatchetForUser(ulong userId, uint ssrc, ushort protocolVersion) + { + using var keyRatchet = GetUserKeyRatchet(userId, protocolVersion); + + var decryptor = _decryptors.GetOrAdd(ssrc, ssrc => + { + var decryptor = DecryptorCreate(); + DecryptorTransitionToPassthroughMode(decryptor, protocolVersion is DisabledVersion); + return decryptor; + }); + + if (keyRatchet is not null) + DecryptorTransitionToKeyRatchet(decryptor, keyRatchet); + } + + private void SetupEncryptionKeyRatchet(ushort protocolVersion) + { + using var keyRatchet = GetUserKeyRatchet(_client.UserId, protocolVersion); + if (keyRatchet is not null) + EncryptorSetKeyRatchet(_encryptor, keyRatchet); + } + + [SkipLocalsInit] + private KeyRatchetHandle? GetUserKeyRatchet(ulong userId, ushort protocolVersion) + { + if (protocolVersion is DisabledVersion) + return null; + + Span byteUserId = stackalloc byte[MaxSnowflakeCStringSize]; + byteUserId = SnowflakeToCString(userId, byteUserId); + return SessionGetKeyRatchet(_session, byteUserId); + } + + private async ValueTask SendMlsKeyPackageAsync(ConnectionState connectionState) + { + byte[] payload; + int payloadLength; + unsafe + { + SessionGetMarshalledKeyPackage(_session, out var keyPackage, out var length); + + payloadLength = (int)length + 1; + + payload = ArrayPool.Shared.Rent(payloadLength); + + new ReadOnlySpan(keyPackage, (int)length).CopyTo(payload.AsSpan(1)); + } + + payload[0] = (byte)VoiceOpcode.DaveMlsKeyPackage; + + await _client.SendConnectionPayloadAsync(connectionState, payload.AsMemory(0, payloadLength), _client._internalBinaryPayloadProperties).ConfigureAwait(false); + + ArrayPool.Shared.Return(payload); + } + + private ValueTask SendTransitionReadyAsync(ConnectionState connectionState, ushort transitionId) + { + VoicePayloadProperties readyPayload = new(VoiceOpcode.DaveTransitionReady, new(transitionId)); + + return _client.SendConnectionPayloadAsync(connectionState, readyPayload.Serialize(Serialization.Default.VoicePayloadPropertiesDaveTransitionReadyProperties), _client._internalTextPayloadProperties); + } + + private ValueTask SendMlsCommitWelcomeAsync(ConnectionState connectionState, ReadOnlySpan commitWelcomeMessage) + { + int payloadLength = commitWelcomeMessage.Length + 1; + var payload = ArrayPool.Shared.Rent(payloadLength); + + commitWelcomeMessage.CopyTo(payload.AsSpan(1)); + payload[0] = (byte)VoiceOpcode.DaveMlsCommitWelcome; + + return ContinueAsync(_client, connectionState, payload, payloadLength); + + static async ValueTask ContinueAsync(VoiceClient client, ConnectionState connectionState, byte[] payload, int payloadLength) + { + await client.SendConnectionPayloadAsync(connectionState, payload.AsMemory(0, payloadLength), client._internalBinaryPayloadProperties).ConfigureAwait(false); + + ArrayPool.Shared.Return(payload); + } + } + + private ValueTask SendMlsInvalidCommitWelcomeAsync(ConnectionState connectionState, ushort transitionId) + { + VoicePayloadProperties invalidCommitWelcomePayload = new(VoiceOpcode.DaveMlsInvalidCommitWelcome, new(transitionId)); + + return _client.SendConnectionPayloadAsync(connectionState, invalidCommitWelcomePayload.Serialize(Serialization.Default.VoicePayloadPropertiesDaveMlsInvalidCommitWelcomeProperties), _client._internalTextPayloadProperties); + } + + [SkipLocalsInit] + private void InitSession(ushort protocolVersion) + { + Span userId = stackalloc byte[MaxSnowflakeCStringSize]; + userId = SnowflakeToCString(_client.UserId, userId); + SessionInit(_session, protocolVersion, _client.ChannelId, userId); + } + + private static Span SnowflakeToCString(ulong snowflake, Span buffer) + { + if (!snowflake.TryFormat(buffer, out int bytesWritten)) + ThrowFailedToFormatSnowflake(); + + buffer[bytesWritten] = 0; + + return buffer[..(bytesWritten + 1)]; + } + + [DoesNotReturn] + private static void ThrowFailedToFormatSnowflake() + { + throw new InvalidOperationException("Failed to format snowflake."); + } + + public void Dispose() + { + _session.Dispose(); + _encryptor.Dispose(); + foreach (var decryptor in _decryptors.Values) + decryptor.Dispose(); + } + } +} diff --git a/NetCord/Gateway/Voice/VoiceClient.cs b/NetCord/Gateway/Voice/VoiceClient.cs index 588594ed9..d68d4b46a 100644 --- a/NetCord/Gateway/Voice/VoiceClient.cs +++ b/NetCord/Gateway/Voice/VoiceClient.cs @@ -1,12 +1,15 @@ using System.Buffers; using System.Buffers.Binary; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; using System.Text.Json; +using NetCord.Gateway.Voice.BinaryModels; using NetCord.Gateway.Voice.Encryption; using NetCord.Gateway.Voice.JsonModels; using NetCord.Gateway.Voice.UdpSockets; +using NetCord.Gateway.WebSockets; using NetCord.Logging; using WebSocketCloseStatus = System.Net.WebSockets.WebSocketCloseStatus; @@ -24,11 +27,13 @@ public override void Abort() } } - internal class UdpState(IUdpConnection connection, IVoiceEncryption encryption) : IDisposable + internal class UdpState(IUdpConnection connection, IVoiceEncryption encryption, DaveSession daveSession) : IDisposable { public IUdpConnection Connection => connection; public IVoiceEncryption Encryption => encryption; + public DaveSession DaveSession => daveSession; + private CancellationTokenProvider? _closedTokenProvider; public bool TryIndicateConnecting(out CancellationToken closedCancellationToken) @@ -59,6 +64,7 @@ public void Dispose() { connection.Dispose(); encryption.Dispose(); + daveSession.Dispose(); _closedTokenProvider?.Dispose(); } } @@ -77,6 +83,8 @@ public void Dispose() public ulong GuildId { get; } + public ulong ChannelId { get; } + public string Token { get; } /// @@ -95,15 +103,17 @@ public void Dispose() private readonly IVoiceEncryptionProvider _encryptionProvider; private readonly IVoiceReceiveHandler _receiveHandler; private readonly TimeSpan _externalSocketAddressDiscoveryTimeout; + private readonly GCHandle _loggerHandle; internal UdpState? _udpState; - public VoiceClient(ulong userId, string sessionId, string endpoint, ulong guildId, string token, VoiceClientConfiguration? configuration = null) : base(configuration ??= new()) + public VoiceClient(ulong userId, string sessionId, string endpoint, ulong guildId, ulong channelId, string token, VoiceClientConfiguration? configuration = null) : base(configuration ??= new()) { UserId = userId; SessionId = sessionId; Uri = new($"wss://{Endpoint = endpoint}?v={(int)configuration.Version.GetValueOrDefault(VoiceApiVersion.V8)}", UriKind.Absolute); GuildId = guildId; + ChannelId = channelId; Token = token; var cacheProvider = configuration.CacheProvider ?? ImmutableVoiceClientCacheProvider.Empty; @@ -112,13 +122,14 @@ public void Dispose() _encryptionProvider = configuration.EncryptionProvider ?? VoiceEncryptionProvider.Instance; _receiveHandler = configuration.ReceiveHandler ?? NullVoiceReceiveHandler.Instance; _externalSocketAddressDiscoveryTimeout = configuration.ExternalSocketAddressDiscoveryTimeout.GetValueOrDefault(new(5 * TimeSpan.TicksPerSecond)); + _loggerHandle = GCHandle.Alloc(_logger); } private protected override ValueTask SendIdentifyAsync(ConnectionState connectionState, CancellationToken cancellationToken = default) { - var serializedPayload = new VoicePayloadProperties(VoiceOpcode.Identify, new(GuildId, UserId, SessionId, Token)).Serialize(Serialization.Default.VoicePayloadPropertiesVoiceIdentifyProperties); + var serializedPayload = new VoicePayloadProperties(VoiceOpcode.Identify, new(GuildId, UserId, SessionId, Token, DaveSession.GetMaxSupportedProtocolVersion())).Serialize(Serialization.Default.VoicePayloadPropertiesVoiceIdentifyProperties); _latencyTimer.Start(); - return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalPayloadProperties, cancellationToken); + return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalTextPayloadProperties, cancellationToken); } private VoiceState CreateState() @@ -190,23 +201,99 @@ private ValueTask TryResumeAsync(ConnectionState connectionState, int sequenceNu { var serializedPayload = new VoicePayloadProperties(VoiceOpcode.Resume, new(GuildId, SessionId, Token, sequenceNumber)).Serialize(Serialization.Default.VoicePayloadPropertiesVoiceResumeProperties); _latencyTimer.Start(); - return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalPayloadProperties, cancellationToken); + return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalTextPayloadProperties, cancellationToken); } private protected override ValueTask HeartbeatAsync(ConnectionState connectionState, CancellationToken cancellationToken = default) { var serializedPayload = new VoicePayloadProperties(VoiceOpcode.Heartbeat, new(Environment.TickCount, SequenceNumber)).Serialize(Serialization.Default.VoicePayloadPropertiesVoiceHeartbeatProperties); _latencyTimer.Start(); - return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalPayloadProperties, cancellationToken); + return SendConnectionPayloadAsync(connectionState, serializedPayload, _internalTextPayloadProperties, cancellationToken); + } + + private protected override ValueTask ProcessPayloadAsync(State state, ConnectionState connectionState, WebSocketMessageType messageType, ReadOnlySpan payload) + { + if (messageType is WebSocketMessageType.Text) + { + var jsonPayload = JsonSerializer.Deserialize(payload, Serialization.Default.JsonVoicePayload)!; + return HandleJsonPayloadAsync(state, connectionState, jsonPayload); + } + else + { + BinaryVoicePayload binaryPayload = new(payload); + return HandleBinaryPayloadAsync(connectionState, binaryPayload); + } } - private protected override ValueTask ProcessPayloadAsync(State state, ConnectionState connectionState, ReadOnlySpan payload) + private ValueTask HandleBinaryPayloadAsync(ConnectionState connectionState, BinaryVoicePayload payload) { - var jsonPayload = JsonSerializer.Deserialize(payload, Serialization.Default.JsonVoicePayload)!; - return HandlePayloadAsync(state, connectionState, jsonPayload); + SequenceNumber = payload.SequencyNumber; + + switch (payload.Opcode) + { + case VoiceOpcode.DaveMlsExternalSender: + { + if (_udpState is not { DaveSession: var session }) + return default; + + var externalSender = payload.Data; + + session.OnMlsExternalSender(externalSender); + } + return default; + case VoiceOpcode.DaveMlsProposals: + { + if (_udpState is not { DaveSession: var session }) + return default; + + var proposals = payload.Data; + + return session.OnMlsProposalsAsync(connectionState, proposals); + } + case VoiceOpcode.DaveMlsAnnounceCommitTransition: + { + if (_udpState is not { DaveSession: var session }) + return default; + + var data = payload.Data; + + var transitionId = BinaryPrimitives.ReadUInt16BigEndian(data); + + var commitBytes = data[2..]; + + return session.OnMlsPrepareCommitTransitionAsync(connectionState, transitionId, commitBytes); + } + case VoiceOpcode.DaveMlsWelcome: + { + if (_udpState is not { DaveSession: var session }) + return default; + + var data = payload.Data; + + var transitionId = BinaryPrimitives.ReadUInt16BigEndian(data); + + var welcome = data[2..]; + + return session.OnMlsWelcomeAsync(connectionState, transitionId, welcome); + } + default: + return default; + } } - private async ValueTask HandlePayloadAsync(State state, ConnectionState connectionState, JsonVoicePayload payload) + private unsafe static void LogMlsFailure(byte* source, byte* reason, void* userData) + { + var logger = (IWebSocketLogger)GCHandle.FromIntPtr((nint)userData).Target!; + + if (logger.IsEnabled(LogLevel.Error)) + { + var sourceStr = Marshal.PtrToStringUTF8((nint)source); + var reasonStr = Marshal.PtrToStringUTF8((nint)reason); + logger.Log(LogLevel.Error, (Source: sourceStr, Reason: reasonStr), null, static (s, e) => $"An MLS error occured: {s.Source} {s.Reason}"); + } + } + + private async ValueTask HandleJsonPayloadAsync(State state, ConnectionState connectionState, JsonVoicePayload payload) { if (payload.SequenceNumber is int sequenceNumber) SequenceNumber = sequenceNumber; @@ -223,7 +310,11 @@ private async ValueTask HandlePayloadAsync(State state, ConnectionState connecti var udpConnection = _udpConnectionProvider.CreateConnection(ip, port); var encryption = _encryptionProvider.GetEncryption(ready.Modes); - UdpState newUdpState = new(udpConnection, encryption); + UdpState newUdpState; + unsafe + { + newUdpState = new(udpConnection, encryption, new(this, &LogMlsFailure, (void*)GCHandle.ToIntPtr(_loggerHandle))); + } if (Interlocked.CompareExchange(ref _udpState, newUdpState, null) is not null) { @@ -268,7 +359,7 @@ private async ValueTask HandlePayloadAsync(State state, ConnectionState connecti Log(LogLevel.Debug, null, null, static (s, e) => "Selecting a protocol."); VoicePayloadProperties protocolPayload = new(VoiceOpcode.SelectProtocol, new("udp", new(ip, port, encryptionName))); - await SendConnectionPayloadAsync(connectionState, protocolPayload.Serialize(Serialization.Default.VoicePayloadPropertiesProtocolProperties), _internalPayloadProperties).ConfigureAwait(false); + await SendConnectionPayloadAsync(connectionState, protocolPayload.Serialize(Serialization.Default.VoicePayloadPropertiesProtocolProperties), _internalTextPayloadProperties).ConfigureAwait(false); await updateLatencyTask; } @@ -277,12 +368,17 @@ private async ValueTask HandlePayloadAsync(State state, ConnectionState connecti { Log(LogLevel.Debug, null, null, static (s, e) => "Session description received."); - if (_udpState is not { Encryption: var encryption }) + if (_udpState is not { Encryption: var encryption, DaveSession: var session }) return; var sessionDescription = payload.Data.GetValueOrDefault().ToObject(Serialization.Default.JsonSessionDescription); + encryption.SetKey(sessionDescription.SecretKey); + var protocolVersion = sessionDescription.DaveProtocolVersion; + + await session.OnSessionDescriptionAsync(connectionState, protocolVersion).ConfigureAwait(false); + Log(LogLevel.Information, null, null, static (s, e) => "Ready."); var readyTask = InvokeEventAsync(_ready); @@ -296,7 +392,7 @@ private async ValueTask HandlePayloadAsync(State state, ConnectionState connecti { var json = payload.Data.GetValueOrDefault().ToObject(Serialization.Default.JsonSpeaking); - await InvokeEventAsync(_speaking, this, json, static json => new SpeakingEventArgs(json), static (client, json) => client.Cache = client.Cache.CacheUser(json.UserId, json.Ssrc)).ConfigureAwait(false); + await InvokeEventAsync(_speaking, this, json, static json => new SpeakingEventArgs(json), static (client, json) => client.Cache = client.Cache.CacheUserSsrc(json.UserId, json.Ssrc)).ConfigureAwait(false); } break; case VoiceOpcode.HeartbeatACK: @@ -338,7 +434,7 @@ private async ValueTask HandlePayloadAsync(State state, ConnectionState connecti Log(LogLevel.Debug, null, null, static (s, e) => "Client connect received."); var json = payload.Data.GetValueOrDefault().ToObject(Serialization.Default.JsonClientConnect); - await InvokeEventAsync(_userConnect, json, static json => new UserConnectEventArgs(json.UserIds)).ConfigureAwait(false); + await InvokeEventAsync(_userConnect, this, json, static json => new UserConnectEventArgs(json.UserIds), static (client, json) => client.Cache = client.Cache.CacheUsers(json.UserIds)).ConfigureAwait(false); } break; case VoiceOpcode.ClientDisconnect: @@ -346,7 +442,43 @@ private async ValueTask HandlePayloadAsync(State state, ConnectionState connecti Log(LogLevel.Debug, null, null, static (s, e) => "Client disconnect received."); var json = payload.Data.GetValueOrDefault().ToObject(Serialization.Default.JsonClientDisconnect); - await InvokeEventAsync(_userDisconnect, this, json, static json => new UserDisconnectEventArgs(json.UserId), static (client, json) => client.Cache = client.Cache.RemoveUser(json.UserId)).ConfigureAwait(false); + + var userDisconnectTask = InvokeEventAsync(_userDisconnect, this, json, static json => new UserDisconnectEventArgs(json.UserId), static (client, json) => client.Cache = client.Cache.RemoveUser(json.UserId)).ConfigureAwait(false); + + if (_udpState is { DaveSession: var session }) + session.OnClientDisconnect(json.UserId); + + await userDisconnectTask; + } + break; + case VoiceOpcode.DavePrepareTransition: + { + if (_udpState is not { DaveSession: var session }) + return; + + var prepareTransition = payload.Data.GetValueOrDefault().ToObject(Serialization.Default.JsonDavePrepareTransition); + + await session.OnPrepareTransitionAsync(connectionState, prepareTransition.TransitionId, prepareTransition.ProtocolVersion).ConfigureAwait(false); + } + break; + case VoiceOpcode.DaveExecuteTransition: + { + if (_udpState is not { DaveSession: var session }) + return; + + var executeTransition = payload.Data.GetValueOrDefault().ToObject(Serialization.Default.JsonDaveExecuteTransition); + + session.OnExecuteTransition(executeTransition.TransitionId); + } + break; + case VoiceOpcode.DavePrepareEpoch: + { + if (_udpState is not { DaveSession: var session }) + return; + + var prepareEpoch = payload.Data.GetValueOrDefault().ToObject(Serialization.Default.JsonDavePrepareEpoch); + + await session.OnPrepareEpoch(connectionState, prepareEpoch.Epoch, prepareEpoch.ProtocolVersion).ConfigureAwait(false); } break; } @@ -434,7 +566,7 @@ private static (string Ip, ushort Port) GetSocketAddress(ReadOnlySpan data private async void HandleDatagramReceive(ReadOnlyMemory datagram) { - if (_udpState is not { Encryption: var encryption }) + if (_udpState is not { Encryption: var encryption, DaveSession: var session }) return; var handlers = _voiceReceive; @@ -449,6 +581,9 @@ private async void HandleDatagramReceive(ReadOnlyMemory datagram) if (!result.Handle) return; + if (session.GetDecryptor(ssrc) is not { } decryptor) + return; + var framesMissed = result.FramesMissed; if (framesMissed is 0) @@ -475,10 +610,11 @@ private async void HandleDatagramReceive(ReadOnlyMemory datagram) ValueTask InvokeEventForReceivedFrameAsync() { - return InvokeEventWithDisposalAsync(handlers, (Encryption: encryption, PacketStorage: packetStorage, Ssrc: ssrc), static data => + return InvokeEventWithDisposalAsync(handlers, (Encryption: encryption, Decryptor: decryptor, PacketStorage: packetStorage, Ssrc: ssrc), static data => { - var packet = data.PacketStorage.Packet; - var encryption = data.Encryption; + var (encryption, decryptor, packetStorage, ssrc) = data; + + var packet = packetStorage.Packet; var plaintextLength = packet.PayloadLength - encryption.Expansion; var array = ArrayPool.Shared.Rent(plaintextLength); @@ -491,7 +627,11 @@ ValueTask InvokeEventForReceivedFrameAsync() : BinaryPrimitives.ReadUInt16BigEndian(packet.Datagram[(packet.HeaderLength + 2)..])) : 0; - return new VoiceReceiveEventArgs(array, extensionLength, plaintextLength - extensionLength, data.Ssrc); + int daveArrayLength = decryptor.GetMaxPlaintextByteSize(plaintextLength - extensionLength); + var daveArray = ArrayPool.Shared.Rent(daveArrayLength); + int written = decryptor.Decrypt(packet.Ssrc, plaintext[extensionLength..], daveArray); + + return new VoiceReceiveEventArgs(daveArray, 0, written, ssrc); }, args => { ArrayPool.Shared.Return(args._buffer!); @@ -549,6 +689,8 @@ protected override void Dispose(bool disposing) Cache.Dispose(); _udpState?.Dispose(); } + + _loggerHandle.Free(); base.Dispose(disposing); } } diff --git a/NetCord/Gateway/Voice/VoiceIdentifyProperties.cs b/NetCord/Gateway/Voice/VoiceIdentifyProperties.cs index 11151cd3f..72fd6c700 100644 --- a/NetCord/Gateway/Voice/VoiceIdentifyProperties.cs +++ b/NetCord/Gateway/Voice/VoiceIdentifyProperties.cs @@ -2,7 +2,7 @@ namespace NetCord.Gateway.Voice; -internal class VoiceIdentifyProperties(ulong guildId, ulong userId, string sessionId, string token) +internal class VoiceIdentifyProperties(ulong guildId, ulong userId, string sessionId, string token, int maxDaveProtocolVersion) { [JsonPropertyName("server_id")] public ulong GuildId { get; set; } = guildId; @@ -15,4 +15,7 @@ internal class VoiceIdentifyProperties(ulong guildId, ulong userId, string sessi [JsonPropertyName("token")] public string Token { get; set; } = token; + + [JsonPropertyName("max_dave_protocol_version")] + public int MaxDaveProtocolVersion { get; set; } = maxDaveProtocolVersion; } diff --git a/NetCord/Gateway/Voice/VoiceOpcode.cs b/NetCord/Gateway/Voice/VoiceOpcode.cs index e4dd01fc7..d7747a4bf 100644 --- a/NetCord/Gateway/Voice/VoiceOpcode.cs +++ b/NetCord/Gateway/Voice/VoiceOpcode.cs @@ -14,4 +14,15 @@ internal enum VoiceOpcode : byte Resumed = 9, ClientConnect = 11, ClientDisconnect = 13, + DavePrepareTransition = 21, + DaveExecuteTransition = 22, + DaveTransitionReady = 23, + DavePrepareEpoch = 24, + DaveMlsExternalSender = 25, + DaveMlsKeyPackage = 26, + DaveMlsProposals = 27, + DaveMlsCommitWelcome = 28, + DaveMlsAnnounceCommitTransition = 29, + DaveMlsWelcome = 30, + DaveMlsInvalidCommitWelcome = 31, } diff --git a/NetCord/Gateway/WebSocketClient.cs b/NetCord/Gateway/WebSocketClient.cs index 34aed6a42..a004cd4c2 100644 --- a/NetCord/Gateway/WebSocketClient.cs +++ b/NetCord/Gateway/WebSocketClient.cs @@ -30,7 +30,7 @@ private Retry() } } - private protected sealed class ConnectionState(IWebSocketConnection connection, IRateLimiter rateLimiter) : IDisposable + internal sealed class ConnectionState(IWebSocketConnection connection, IRateLimiter rateLimiter) : IDisposable { public IWebSocketConnection Connection => connection; @@ -189,7 +189,7 @@ private protected WebSocketClient(IWebSocketClientConfiguration configuration) _reconnectStrategy = configuration.ReconnectStrategy ?? new ReconnectStrategy(); _latencyTimer = configuration.LatencyTimer ?? new LatencyTimer(); _rateLimiterProvider = configuration.RateLimiterProvider ?? NullRateLimiterProvider.Instance; - _defaultPayloadProperties = CreatePayloadProperties(configuration.DefaultPayloadProperties); + _defaultTextPayloadProperties = CreatePayloadProperties(configuration.DefaultPayloadProperties); _logger = configuration.Logger ?? NullLogger.Instance; } @@ -221,8 +221,9 @@ public readonly InternalWebSocketPayloadProperties Compose(WebSocketPayloadPrope private readonly IWebSocketConnectionProvider _connectionProvider; private readonly IReconnectStrategy _reconnectStrategy; private readonly IRateLimiterProvider _rateLimiterProvider; - private readonly InternalWebSocketPayloadProperties _defaultPayloadProperties; - private protected readonly InternalWebSocketPayloadProperties _internalPayloadProperties = new(default, WebSocketMessageFlags.EndOfMessage, WebSocketRetryHandling.RetryRateLimit); + private readonly InternalWebSocketPayloadProperties _defaultTextPayloadProperties; + private protected readonly InternalWebSocketPayloadProperties _internalTextPayloadProperties = new(WebSocketMessageType.Text, WebSocketMessageFlags.EndOfMessage, WebSocketRetryHandling.RetryRateLimit); + private protected readonly InternalWebSocketPayloadProperties _internalBinaryPayloadProperties = new(WebSocketMessageType.Binary, WebSocketMessageFlags.EndOfMessage, WebSocketRetryHandling.RetryRateLimit); private protected readonly ILatencyTimer _latencyTimer; private protected readonly IWebSocketLogger _logger; @@ -378,11 +379,11 @@ private async void HandleClosed() await InvokeEventAsync(_close).ConfigureAwait(false); } - private async void HandleMessageReceived(State state, ConnectionState connectionState, ReadOnlyMemory data) + private async void HandleMessageReceived(State state, ConnectionState connectionState, WebSocketMessageType messageType, ReadOnlyMemory data) { try { - await ProcessPayloadAsync(state, connectionState, data.Span).ConfigureAwait(false); + await ProcessPayloadAsync(state, connectionState, messageType, data.Span).ConfigureAwait(false); } catch (Exception ex) { @@ -538,7 +539,7 @@ private async Task ReadAsync(State state, ConnectionState connectionState) break; writer.Advance(result.Count); - HandleMessageReceived(state, connectionState, writer.WrittenMemory); + HandleMessageReceived(state, connectionState, result.MessageType, writer.WrittenMemory); writer.Clear(); } else @@ -610,7 +611,7 @@ private protected ValueTask AbortAndResumeAsync(State state, ConnectionState con public async ValueTask SendPayloadAsync(ReadOnlyMemory buffer, WebSocketPayloadProperties? properties = null, CancellationToken cancellationToken = default) { - var payloadProperties = _defaultPayloadProperties.Compose(properties); + var payloadProperties = _defaultTextPayloadProperties.Compose(properties); while (true) { @@ -844,7 +845,7 @@ private protected async void StartHeartbeating(ConnectionState connectionState, private protected abstract ValueTask HeartbeatAsync(ConnectionState connectionState, CancellationToken cancellationToken = default); - private protected abstract ValueTask ProcessPayloadAsync(State state, ConnectionState connectionState, ReadOnlySpan payload); + private protected abstract ValueTask ProcessPayloadAsync(State state, ConnectionState connectionState, WebSocketMessageType messageType, ReadOnlySpan payload); private protected ValueTask UpdateLatencyAsync(TimeSpan latency) => InvokeEventAsync(_latencyUpdate, this, latency, static (client, latency) => Interlocked.Exchange(ref Unsafe.As(ref client._latency), Unsafe.As(ref latency))); diff --git a/NetCord/Serialization.cs b/NetCord/Serialization.cs index 1566176fa..75e0d6319 100644 --- a/NetCord/Serialization.cs +++ b/NetCord/Serialization.cs @@ -251,5 +251,10 @@ namespace NetCord; [JsonSerializable(typeof(JsonWebhookEventArgs))] [JsonSerializable(typeof(JsonRateLimitedEventArgs))] [JsonSerializable(typeof(JsonRequestGuildUsersRateLimitMetadata))] +[JsonSerializable(typeof(JsonDaveExecuteTransition))] +[JsonSerializable(typeof(JsonDavePrepareEpoch))] +[JsonSerializable(typeof(JsonDavePrepareTransition))] +[JsonSerializable(typeof(VoicePayloadProperties))] +[JsonSerializable(typeof(VoicePayloadProperties))] [JsonSerializable(typeof(JsonGuildMessagesSearchResult))] internal partial class Serialization : JsonSerializerContext; diff --git a/SourceGenerators/WebSocketClientEventsGenerator/WebSocketClientEventsGenerator.cs b/SourceGenerators/WebSocketClientEventsGenerator/WebSocketClientEventsGenerator.cs index d956903fa..99c19805d 100644 --- a/SourceGenerators/WebSocketClientEventsGenerator/WebSocketClientEventsGenerator.cs +++ b/SourceGenerators/WebSocketClientEventsGenerator/WebSocketClientEventsGenerator.cs @@ -19,9 +19,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var typeSymbols = context.SyntaxProvider .CreateSyntaxProvider((node, cancellationToken) => node is ClassDeclarationSyntax { Identifier.Text: WebSocketClientName or GatewayClientName or VoiceClientName }, - (context, cancellationToken) => (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!); + (context, cancellationToken) => (INamedTypeSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!) + .Collect(); - context.RegisterSourceOutput(typeSymbols, (context, source) => context.AddSource($"{source.Name}.g.cs", SourceText.From(GenerateEvents(source), Encoding.UTF8))); + context.RegisterSourceOutput(typeSymbols, (context, source) => + { + foreach (var typeSymbol in source.Distinct(SymbolEqualityComparer.Default).Cast()) + context.AddSource($"{typeSymbol.Name}.g.cs", SourceText.From(GenerateEvents(typeSymbol), Encoding.UTF8)); + }); } private string GenerateEvents(INamedTypeSymbol typeSymbol) diff --git a/Tests/NetCord.Test/ApplicationCommands/VoiceCommands.cs b/Tests/NetCord.Test/ApplicationCommands/VoiceCommands.cs index 7616ca773..5fedf978d 100644 --- a/Tests/NetCord.Test/ApplicationCommands/VoiceCommands.cs +++ b/Tests/NetCord.Test/ApplicationCommands/VoiceCommands.cs @@ -12,10 +12,15 @@ namespace NetCord.Test.ApplicationCommands; public class VoiceCommands(Dictionary joinSemaphores) : ApplicationCommandModule { - private async Task JoinAsync(VoiceEncryption? encryption, Func? disconnectHandler = null) + private async Task JoinAsync(IVoiceGuildChannel? channel, VoiceEncryption? encryption, Func? disconnectHandler = null) { var guild = Context.Guild!; - if (!guild.VoiceStates.TryGetValue(Context.User.Id, out var state)) + ulong channelId; + if (channel is not null) + channelId = channel.Id; + else if (guild.VoiceStates.TryGetValue(Context.User.Id, out var state)) + channelId = state.ChannelId.GetValueOrDefault(); + else throw new("You are not in a voice channel!"); var client = Context.Client; @@ -43,7 +48,7 @@ private async Task JoinAsync(VoiceEncryption? encryption, Func throw new InvalidEnumArgumentException(nameof(encryption), (int)encryption, typeof(VoiceEncryption)), }) : null; - voiceClient = await client.JoinVoiceChannelAsync(guild.Id, state.ChannelId.GetValueOrDefault(), new() + voiceClient = await client.JoinVoiceChannelAsync(guild.Id, channelId, new() { EncryptionProvider = encryptionProvider, ReceiveHandler = new VoiceReceiveHandler(), @@ -65,11 +70,11 @@ private async Task JoinAsync(VoiceEncryption? encryption, Func + var voiceClient = await JoinAsync(channel, encryption, args => { if (!args.Reconnect) cancellationTokenSource.Cancel(); @@ -83,32 +88,38 @@ public async Task PlayAsync(VoiceEncryption? encryption = null) var url = "https://www.mfiles.co.uk/mp3-downloads/beethoven-symphony6-1.mp3"; // 00:12:08 //var url = "https://file-examples.com/storage/feee5c69f0643c59da6bf13/2017/11/file_example_MP3_700KB.mp3"; // 00:00:27 await RespondAsync(InteractionCallback.Message($"Playing: {Path.GetFileNameWithoutExtension(url)}")); - using var ffmpeg = Process.Start(new ProcessStartInfo - { - FileName = "ffmpeg", - Arguments = $"-i \"{url}\" -ac 2 -f f32le -ar 48000 pipe:1", - RedirectStandardOutput = true, - })!; var token = cancellationTokenSource.Token; - try + do { - await ffmpeg.StandardOutput.BaseStream.CopyToAsync(opusEncodeStream, token); - await opusEncodeStream.FlushAsync(token); - await Task.Delay(-1, token); - } - catch (OperationCanceledException) - { - ffmpeg.Kill(); + using var ffmpeg = Process.Start(new ProcessStartInfo + { + FileName = "ffmpeg", + Arguments = $"-i \"{url}\" -ac 2 -f f32le -ar 48000 pipe:1", + RedirectStandardOutput = true, + })!; + try + { + await ffmpeg.StandardOutput.BaseStream.CopyToAsync(opusEncodeStream, token); + await opusEncodeStream.FlushAsync(token); + } + catch (OperationCanceledException) + { + ffmpeg.Kill(); + return; + } } + while (loop); + + await Task.Delay(-1, token); } [SlashCommand("echo", "Echo!")] - public async Task EchoAsync(VoiceEncryption? encryption = null) + public async Task EchoAsync(IVoiceGuildChannel? channel = null, VoiceEncryption? encryption = null) { TaskCompletionSource taskCompletionSource = new(); - var voiceClient = await JoinAsync(encryption, args => + var voiceClient = await JoinAsync(channel, encryption, args => { if (!args.Reconnect) taskCompletionSource.TrySetResult(); @@ -134,6 +145,39 @@ public async Task EchoAsync(VoiceEncryption? encryption = null) await taskCompletionSource.Task; } + [SlashCommand("record", "Record!")] + public async Task RecordAsync(IVoiceGuildChannel? channel = null, VoiceEncryption? encryption = null) + { + TaskCompletionSource taskCompletionSource = new(); + + using var ffmpeg = Process.Start(new ProcessStartInfo + { + FileName = "ffmpeg", + Arguments = $"-f s16le -ar 48000 -ac 2 -i pipe:0 recording-{DateTimeOffset.UtcNow.ToUnixTimeSeconds()}.wav", + RedirectStandardInput = true, + })!; + + var voiceClient = await JoinAsync(channel, encryption, args => + { + if (!args.Reconnect) + taskCompletionSource.TrySetResult(); + + return default; + }); + await RespondAsync(InteractionCallback.Message("Recording!")); + + OpusDecodeStream opusDecodeStream = new(ffmpeg.StandardInput.BaseStream, PcmFormat.Short, VoiceChannels.Stereo); + + voiceClient.VoiceReceive += args => + { + opusDecodeStream.Write(args.Frame); + return default; + }; + + await taskCompletionSource.Task; + ffmpeg.Kill(); + } + public enum VoiceEncryption : byte { XSalsa20Poly1305, diff --git a/Tests/NetCord.Test/NetCord.Test.csproj b/Tests/NetCord.Test/NetCord.Test.csproj index 4c051f79e..335ed7c71 100644 --- a/Tests/NetCord.Test/NetCord.Test.csproj +++ b/Tests/NetCord.Test/NetCord.Test.csproj @@ -21,6 +21,12 @@ + + PreserveNewest + + + PreserveNewest + PreserveNewest diff --git a/Tests/NetCord.Test/dave.dll b/Tests/NetCord.Test/dave.dll new file mode 100644 index 000000000..0642bb0b4 Binary files /dev/null and b/Tests/NetCord.Test/dave.dll differ