diff --git a/DnsClientX.Tests/DnsMessageRecursionDesiredTests.cs b/DnsClientX.Tests/DnsMessageRecursionDesiredTests.cs new file mode 100644 index 00000000..cd2c84f2 --- /dev/null +++ b/DnsClientX.Tests/DnsMessageRecursionDesiredTests.cs @@ -0,0 +1,50 @@ +using System; +using Xunit; + +namespace DnsClientX.Tests { + /// + /// Tests for handling in DNS wire serialization. + /// + public class DnsMessageRecursionDesiredTests { + private const ushort RdFlag = 0x0100; + + /// + /// Ensures the RD bit is cleared when recursion is not desired. + /// + [Fact] + public void SerializeDnsWireFormat_ShouldClearRdBit_WhenRecursionDesiredFalse() { + var opts = new DnsMessageOptions(RecursionDesired: false); + var message = new DnsMessage("example.com", DnsRecordType.A, opts); + + byte[] query = message.SerializeDnsWireFormat(); + + ushort flags = ReadFlags(query); + Assert.True((flags & RdFlag) == 0); + } + + /// + /// Ensures the RD bit is set when recursion is desired. + /// + [Fact] + public void SerializeDnsWireFormat_ShouldSetRdBit_WhenRecursionDesiredTrue() { + var opts = new DnsMessageOptions(RecursionDesired: true); + var message = new DnsMessage("example.com", DnsRecordType.A, opts); + + byte[] query = message.SerializeDnsWireFormat(); + + ushort flags = ReadFlags(query); + Assert.True((flags & RdFlag) != 0); + } + + private static ushort ReadFlags(byte[] query) { + if (query == null) { + throw new ArgumentNullException(nameof(query)); + } + if (query.Length < 4) { + throw new ArgumentException("DNS message is too short to contain a header.", nameof(query)); + } + + return (ushort)((query[2] << 8) | query[3]); + } + } +} diff --git a/DnsClientX.Tests/DnsWireMessageParserTests.cs b/DnsClientX.Tests/DnsWireMessageParserTests.cs new file mode 100644 index 00000000..e076cde4 --- /dev/null +++ b/DnsClientX.Tests/DnsWireMessageParserTests.cs @@ -0,0 +1,229 @@ +using System; +using System.Collections.Generic; +using Xunit; + +namespace DnsClientX.Tests { + /// + /// Tests for helpers. + /// + public class DnsWireMessageParserTests { + private const int DnsHeaderLength = 12; + private const ushort TcFlag = 0x0200; + private const ushort RdFlag = 0x0100; + private const ushort RaFlag = 0x0080; + + /// + /// Ensures header parsing extracts flags and section counts. + /// + [Fact] + public void TryParseHeader_ShouldParseFlagsAndCounts() { + ushort flags = (ushort)(TcFlag | RdFlag | RaFlag | (ushort)DnsResponseCode.NXDomain); + byte[] data = CreateHeader(flags, qd: 1, an: 2, ns: 3, ar: 4); + + bool ok = DnsWireMessageParser.TryParseHeader(data, out var header); + + Assert.True(ok); + Assert.True(header.IsTruncated); + Assert.True(header.IsRecursionAvailable); + Assert.True(header.IsRecursionDesired); + Assert.Equal(DnsResponseCode.NXDomain, header.ResponseCode); + Assert.Equal((ushort)1, header.QuestionCount); + Assert.Equal((ushort)2, header.AnswerCount); + Assert.Equal((ushort)3, header.AuthorityCount); + Assert.Equal((ushort)4, header.AdditionalCount); + } + + /// + /// Ensures header parsing fails for truncated data. + /// + [Fact] + public void TryParseHeader_ShouldReturnFalse_WhenDataTooShort() { + byte[] data = new byte[DnsHeaderLength - 1]; + + bool ok = DnsWireMessageParser.TryParseHeader(data, out var header); + + Assert.False(ok); + Assert.Equal(default, header); + } + + /// + /// Ensures EDNS parsing finds the OPT record and returns the UDP payload size. + /// + [Fact] + public void TryParseEdns_ShouldReturnUdpPayloadSize_WhenOptPresent() { + const ushort udpPayloadSize = 1232; + var opts = new DnsMessageOptions(EnableEdns: true, UdpBufferSize: udpPayloadSize); + var message = new DnsMessage("example.com", DnsRecordType.A, opts); + + byte[] data = message.SerializeDnsWireFormat(); + + bool ok = DnsWireMessageParser.TryParseEdns(data, out var edns); + + Assert.True(ok); + Assert.True(edns.Supported); + Assert.Equal(udpPayloadSize, edns.UdpPayloadSize); + } + + /// + /// Ensures EDNS parsing returns a non-supported result when no OPT record is present. + /// + [Fact] + public void TryParseEdns_ShouldReturnUnsupported_WhenNoOptPresent() { + var opts = new DnsMessageOptions(EnableEdns: false); + var message = new DnsMessage("example.com", DnsRecordType.A, opts); + + byte[] data = message.SerializeDnsWireFormat(); + + bool ok = DnsWireMessageParser.TryParseEdns(data, out var edns); + + Assert.True(ok); + Assert.False(edns.Supported); + Assert.Equal(0, edns.UdpPayloadSize); + } + + /// + /// Ensures EDNS parsing fails when the DNS message header is truncated. + /// + [Fact] + public void TryParseEdns_ShouldReturnFalse_WhenDataTooShort() { + byte[] data = new byte[DnsHeaderLength - 1]; + + bool ok = DnsWireMessageParser.TryParseEdns(data, out _); + + Assert.False(ok); + } + + /// + /// Ensures EDNS parsing fails when questions are declared but not present. + /// + [Fact] + public void TryParseEdns_ShouldReturnFalse_WhenQuestionIsTruncated() { + byte[] data = CreateHeader(flags: 0, qd: 1, an: 0, ns: 0, ar: 0); + + bool ok = DnsWireMessageParser.TryParseEdns(data, out _); + + Assert.False(ok); + } + + /// + /// Ensures EDNS parsing fails for invalid label lengths (> 63). + /// + [Fact] + public void TryParseEdns_ShouldReturnFalse_WhenQuestionLabelTooLong() { + byte[] header = CreateHeader(flags: 0, qd: 1, an: 0, ns: 0, ar: 0); + byte[] data = new byte[header.Length + 1]; + Array.Copy(header, 0, data, 0, header.Length); + data[DnsHeaderLength] = 64; + + bool ok = DnsWireMessageParser.TryParseEdns(data, out _); + + Assert.False(ok); + } + + /// + /// Ensures EDNS parsing can handle compressed names in the question section. + /// + [Fact] + public void TryParseEdns_ShouldHandleCompressedQuestionName() { + const ushort udpPayloadSize = 1232; + var bytes = new List(); + + byte[] header = CreateHeader(flags: 0, qd: 2, an: 0, ns: 0, ar: 1); + bytes.AddRange(header); + + int firstQuestionNameOffset = bytes.Count; + AppendQuestion(bytes, "example.com", DnsRecordType.A, qclass: 1); + + AppendCompressedQuestion(bytes, pointerOffset: (ushort)firstQuestionNameOffset, DnsRecordType.A, qclass: 1); + + AppendOptRecord(bytes, udpPayloadSize); + + bool ok = DnsWireMessageParser.TryParseEdns(bytes.ToArray(), out var edns); + + Assert.True(ok); + Assert.True(edns.Supported); + Assert.Equal(udpPayloadSize, edns.UdpPayloadSize); + } + + /// + /// Ensures EDNS parsing fails when a compression pointer is truncated. + /// + [Fact] + public void TryParseEdns_ShouldReturnFalse_WhenCompressionPointerTruncated() { + byte[] header = CreateHeader(flags: 0, qd: 1, an: 0, ns: 0, ar: 0); + byte[] data = new byte[header.Length + 1]; + Array.Copy(header, 0, data, 0, header.Length); + data[DnsHeaderLength] = 0xC0; + + bool ok = DnsWireMessageParser.TryParseEdns(data, out _); + + Assert.False(ok); + } + + private static byte[] CreateHeader(ushort flags, ushort qd, ushort an, ushort ns, ushort ar) { + var data = new byte[DnsHeaderLength]; + WriteUInt16At(data, 0, 1); // ID + WriteUInt16At(data, 2, flags); + WriteUInt16At(data, 4, qd); + WriteUInt16At(data, 6, an); + WriteUInt16At(data, 8, ns); + WriteUInt16At(data, 10, ar); + return data; + } + + private static void WriteUInt16At(byte[] buffer, int offset, ushort value) { + buffer[offset] = (byte)(value >> 8); + buffer[offset + 1] = (byte)(value & 0xFF); + } + + private static void AppendQuestion(List message, string name, DnsRecordType qtype, ushort qclass) { + if (message == null) { + throw new ArgumentNullException(nameof(message)); + } + if (string.IsNullOrWhiteSpace(name)) { + throw new ArgumentException("Name must not be empty.", nameof(name)); + } + + foreach (var label in name.TrimEnd('.').Split('.')) { + message.Add((byte)label.Length); + foreach (char c in label) { + message.Add((byte)c); + } + } + message.Add(0x00); + + message.Add((byte)(((ushort)qtype >> 8) & 0xFF)); + message.Add((byte)((ushort)qtype & 0xFF)); + message.Add((byte)(qclass >> 8)); + message.Add((byte)(qclass & 0xFF)); + } + + private static void AppendCompressedQuestion(List message, ushort pointerOffset, DnsRecordType qtype, ushort qclass) { + // Compression pointer: 11xx xxxx xxxx xxxx (14-bit offset). + byte first = (byte)(0xC0 | ((pointerOffset >> 8) & 0x3F)); + byte second = (byte)(pointerOffset & 0xFF); + + message.Add(first); + message.Add(second); + + message.Add((byte)(((ushort)qtype >> 8) & 0xFF)); + message.Add((byte)((ushort)qtype & 0xFF)); + message.Add((byte)(qclass >> 8)); + message.Add((byte)(qclass & 0xFF)); + } + + private static void AppendOptRecord(List message, ushort udpPayloadSize) { + message.Add(0x00); // root name + message.Add(0x00); + message.Add(0x29); // TYPE OPT + message.Add((byte)(udpPayloadSize >> 8)); + message.Add((byte)(udpPayloadSize & 0xFF)); + message.Add(0x00); + message.Add(0x00); + message.Add(0x00); + message.Add(0x00); // TTL + message.Add(0x00); + message.Add(0x00); // RDLEN + } + } +} diff --git a/DnsClientX/DnsMessageOptions.cs b/DnsClientX/DnsMessageOptions.cs index 496b4fc0..b622aa0b 100644 --- a/DnsClientX/DnsMessageOptions.cs +++ b/DnsClientX/DnsMessageOptions.cs @@ -13,4 +13,5 @@ public readonly record struct DnsMessageOptions( EdnsClientSubnetOption? Subnet = null, bool CheckingDisabled = false, AsymmetricAlgorithm? SigningKey = null, - IEnumerable? Options = null); + IEnumerable? Options = null, + bool RecursionDesired = true); diff --git a/DnsClientX/ProtocolDnsWire/DnsMessage.cs b/DnsClientX/ProtocolDnsWire/DnsMessage.cs index 84cee5ab..06b03c65 100644 --- a/DnsClientX/ProtocolDnsWire/DnsMessage.cs +++ b/DnsClientX/ProtocolDnsWire/DnsMessage.cs @@ -14,6 +14,7 @@ namespace DnsClientX { public class DnsMessage { private readonly string _name; private readonly DnsRecordType _type; + private readonly bool _recursionDesired; private readonly bool _requestDnsSec; private readonly bool _enableEdns; private readonly int _udpBufferSize; @@ -58,6 +59,7 @@ public DnsMessage(string name, DnsRecordType type, bool requestDnsSec, bool enab public DnsMessage(string name, DnsRecordType type, DnsMessageOptions options) { _name = name; _type = type; + _recursionDesired = options.RecursionDesired; _requestDnsSec = options.RequestDnsSec; _ednsOptions = options.Options?.ToArray() ?? Array.Empty(); _enableEdns = options.EnableEdns || options.RequestDnsSec || options.Subnet != null || options.CheckingDisabled || _ednsOptions.Length > 0; @@ -88,7 +90,8 @@ public string ToBase64Url() { //stream.Write(buffer.ToArray(), 0, buffer.Length); // Write the flags - BinaryPrimitives.WriteUInt16BigEndian(buffer, 0x0100); + ushort headerFlags = _recursionDesired ? (ushort)0x0100 : (ushort)0x0000; + BinaryPrimitives.WriteUInt16BigEndian(buffer, headerFlags); stream.Write(buffer.ToArray(), 0, buffer.Length); // Write the flags @@ -216,7 +219,8 @@ public byte[] SerializeDnsWireFormat() { ms.Write(bytes, 0, bytes.Length); // Flags - bytes = BitConverter.GetBytes(IPAddress.HostToNetworkOrder((short)0x0100)); // Standard query + short headerFlags = _recursionDesired ? (short)0x0100 : (short)0x0000; // Standard query (RD optional) + bytes = BitConverter.GetBytes(IPAddress.HostToNetworkOrder(headerFlags)); ms.Write(bytes, 0, bytes.Length); // Questions diff --git a/DnsClientX/ProtocolDnsWire/DnsWireMessageParser.cs b/DnsClientX/ProtocolDnsWire/DnsWireMessageParser.cs new file mode 100644 index 00000000..ab30c0c6 --- /dev/null +++ b/DnsClientX/ProtocolDnsWire/DnsWireMessageParser.cs @@ -0,0 +1,210 @@ +using System; + +namespace DnsClientX { + /// + /// Represents a parsed DNS message header (wire format). + /// + public readonly record struct DnsWireHeaderInfo( + bool IsTruncated, + bool IsRecursionAvailable, + bool IsRecursionDesired, + DnsResponseCode ResponseCode, + ushort QuestionCount, + ushort AnswerCount, + ushort AuthorityCount, + ushort AdditionalCount); + + /// + /// Represents basic EDNS (OPT) information extracted from a DNS message (wire format). + /// + public readonly record struct DnsWireEdnsInfo(bool Supported, int UdpPayloadSize); + + /// + /// Provides lightweight, safe parsing helpers for DNS wire-format messages. + /// + public static class DnsWireMessageParser { + private const int DnsHeaderLength = 12; + private const int QuestionTypeAndClassLength = 4; + private const int MaxNameSegmentsToSkip = 50; + private const int FlagsOffset = 2; + private const int QuestionCountOffset = 4; + private const int AnswerCountOffset = 6; + private const int AuthorityCountOffset = 8; + private const int AdditionalCountOffset = 10; + private const ushort TcFlag = 0x0200; + private const ushort RdFlag = 0x0100; + private const ushort RaFlag = 0x0080; + private const ushort RcodeMask = 0x000F; + private const byte CompressionPointerMask = 0xC0; + private const byte CompressionPointerValue = 0xC0; + private const byte MaxLabelLength = 63; + + /// + /// Attempts to parse the DNS message header (flags and section counts). + /// + /// Raw DNS message bytes. + /// Parsed header fields. + /// true when the header was parsed; otherwise false. + public static bool TryParseHeader(byte[]? data, out DnsWireHeaderInfo header) { + header = default; + if (data == null || data.Length < DnsHeaderLength) { + return false; + } + + ushort flags = ReadUInt16At(data, FlagsOffset); + ushort qd = ReadUInt16At(data, QuestionCountOffset); + ushort an = ReadUInt16At(data, AnswerCountOffset); + ushort ns = ReadUInt16At(data, AuthorityCountOffset); + ushort ar = ReadUInt16At(data, AdditionalCountOffset); + + header = new DnsWireHeaderInfo( + IsTruncated: (flags & TcFlag) != 0, + IsRecursionAvailable: (flags & RaFlag) != 0, + IsRecursionDesired: (flags & RdFlag) != 0, + ResponseCode: (DnsResponseCode)(flags & RcodeMask), + QuestionCount: qd, + AnswerCount: an, + AuthorityCount: ns, + AdditionalCount: ar); + + return true; + } + + /// + /// Attempts to locate an EDNS OPT record and extract the advertised UDP payload size. + /// + /// Raw DNS message bytes. + /// Parsed EDNS info. + /// + /// true when parsing succeeded (even if EDNS is not present); otherwise false for malformed messages. + /// + public static bool TryParseEdns(byte[]? data, out DnsWireEdnsInfo edns) { + edns = default; + if (data == null || data.Length < DnsHeaderLength) { + return false; + } + + int offset = QuestionCountOffset; + if (!TryReadUInt16(data, ref offset, out var qd)) { + return false; + } + if (!TryReadUInt16(data, ref offset, out var an)) { + return false; + } + if (!TryReadUInt16(data, ref offset, out var ns)) { + return false; + } + if (!TryReadUInt16(data, ref offset, out var ar)) { + return false; + } + + offset = DnsHeaderLength; + for (int i = 0; i < qd; i++) { + if (!TrySkipName(data, ref offset)) { + return false; + } + if (offset + QuestionTypeAndClassLength > data.Length) { + return false; + } + offset += QuestionTypeAndClassLength; // QTYPE + QCLASS + } + + int rrCount = an + ns + ar; + for (int i = 0; i < rrCount; i++) { + if (!TrySkipName(data, ref offset)) { + return false; + } + if (!TryReadUInt16(data, ref offset, out var type)) { + return false; + } + if (!TryReadUInt16(data, ref offset, out var rrClass)) { + return false; + } + if (!TryReadUInt32(data, ref offset, out _)) { + return false; + } + if (!TryReadUInt16(data, ref offset, out var rdlen)) { + return false; + } + if (offset + rdlen > data.Length) { + return false; + } + + if (type == (ushort)DnsRecordType.OPT) { + edns = new DnsWireEdnsInfo(Supported: true, UdpPayloadSize: rrClass); + return true; + } + + offset += rdlen; + } + + edns = new DnsWireEdnsInfo(Supported: false, UdpPayloadSize: 0); + return true; + } + + private static bool TrySkipName(byte[] buffer, ref int offset) { + int segments = 0; + while (true) { + if (buffer == null || offset < 0 || offset >= buffer.Length) { + return false; + } + + var len = buffer[offset++]; + if (len == 0) { + return true; + } + + // Compression pointer (RFC 1035 4.1.4): 2 bytes total. + if ((len & CompressionPointerMask) == CompressionPointerValue) { + if (offset >= buffer.Length) { + return false; + } + offset++; + return true; + } + + if (len > MaxLabelLength) { + return false; + } + + if (offset + len > buffer.Length) { + return false; + } + + offset += len; + if (++segments > MaxNameSegmentsToSkip) { + return false; + } + } + } + + private static bool TryReadUInt16(byte[] buffer, ref int offset, out ushort value) { + value = 0; + if (buffer == null || offset < 0 || offset + 2 > buffer.Length) { + return false; + } + + value = ReadUInt16At(buffer, offset); + offset += 2; + return true; + } + + private static bool TryReadUInt32(byte[] buffer, ref int offset, out uint value) { + value = 0; + if (buffer == null || offset < 0 || offset + 4 > buffer.Length) { + return false; + } + + value = ((uint)buffer[offset] << 24) + | ((uint)buffer[offset + 1] << 16) + | ((uint)buffer[offset + 2] << 8) + | buffer[offset + 3]; + offset += 4; + return true; + } + + private static ushort ReadUInt16At(byte[] buffer, int offset) { + return (ushort)(((ushort)buffer[offset] << 8) | buffer[offset + 1]); + } + } +} diff --git a/DnsClientX/Throttling/AsyncIntervalGate.cs b/DnsClientX/Throttling/AsyncIntervalGate.cs new file mode 100644 index 00000000..6477c4fc --- /dev/null +++ b/DnsClientX/Throttling/AsyncIntervalGate.cs @@ -0,0 +1,63 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace DnsClientX.Throttling; + +/// +/// Provides a simple async gate that ensures callers do not proceed more often than a configured interval. +/// +/// +/// This is useful for rate limiting DNS queries across concurrent tasks. +/// +public sealed class AsyncIntervalGate : IDisposable +{ + private readonly TimeSpan _interval; + private readonly SemaphoreSlim _mutex = new(1, 1); + private DateTime _nextUtc; + + /// + /// Initializes a new instance of the class. + /// + /// Minimum interval between permits; negative values are treated as zero. + public AsyncIntervalGate(TimeSpan interval) + { + _interval = interval < TimeSpan.Zero ? TimeSpan.Zero : interval; + _nextUtc = DateTime.MinValue; + } + + /// + /// Waits until the next permit is available and then consumes it. + /// + /// Token used to cancel the wait. + public async Task WaitAsync(CancellationToken cancellationToken = default) + { + await _mutex.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + var now = DateTime.UtcNow; + if (_nextUtc > now) + { + var delay = _nextUtc - now; + if (delay > TimeSpan.Zero) + { + await Task.Delay(delay, cancellationToken).ConfigureAwait(false); + } + now = DateTime.UtcNow; + } + + _nextUtc = now + _interval; + } + finally + { + _mutex.Release(); + } + } + + /// + public void Dispose() + { + _mutex.Dispose(); + } +} +