Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions DnsClientX.Tests/DnsMessageRecursionDesiredTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using System;
using Xunit;

namespace DnsClientX.Tests {
/// <summary>
/// Tests for <see cref="DnsMessageOptions.RecursionDesired"/> handling in DNS wire serialization.
/// </summary>
public class DnsMessageRecursionDesiredTests {
private const ushort RdFlag = 0x0100;

/// <summary>
/// Ensures the RD bit is cleared when recursion is not desired.
/// </summary>
[Fact]
public void SerializeDnsWireFormat_ShouldClearRdBit_WhenRecursionDesiredFalse() {
var opts = new DnsMessageOptions(RecursionDesired: false);
var message = new DnsMessage("example.com", DnsRecordType.A, opts);

byte[] query = message.SerializeDnsWireFormat();

ushort flags = ReadFlags(query);
Assert.True((flags & RdFlag) == 0);
}

/// <summary>
/// Ensures the RD bit is set when recursion is desired.
/// </summary>
[Fact]
public void SerializeDnsWireFormat_ShouldSetRdBit_WhenRecursionDesiredTrue() {
var opts = new DnsMessageOptions(RecursionDesired: true);
var message = new DnsMessage("example.com", DnsRecordType.A, opts);

byte[] query = message.SerializeDnsWireFormat();

ushort flags = ReadFlags(query);
Assert.True((flags & RdFlag) != 0);
}

private static ushort ReadFlags(byte[] query) {
if (query == null) {
throw new ArgumentNullException(nameof(query));
}
if (query.Length < 4) {
throw new ArgumentException("DNS message is too short to contain a header.", nameof(query));
}

return (ushort)((query[2] << 8) | query[3]);
}
}
}
229 changes: 229 additions & 0 deletions DnsClientX.Tests/DnsWireMessageParserTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
using System;
using System.Collections.Generic;
using Xunit;

namespace DnsClientX.Tests {
/// <summary>
/// Tests for <see cref="DnsWireMessageParser"/> helpers.
/// </summary>
public class DnsWireMessageParserTests {
private const int DnsHeaderLength = 12;
private const ushort TcFlag = 0x0200;
private const ushort RdFlag = 0x0100;
private const ushort RaFlag = 0x0080;

/// <summary>
/// Ensures header parsing extracts flags and section counts.
/// </summary>
[Fact]
public void TryParseHeader_ShouldParseFlagsAndCounts() {
ushort flags = (ushort)(TcFlag | RdFlag | RaFlag | (ushort)DnsResponseCode.NXDomain);
byte[] data = CreateHeader(flags, qd: 1, an: 2, ns: 3, ar: 4);

bool ok = DnsWireMessageParser.TryParseHeader(data, out var header);

Assert.True(ok);
Assert.True(header.IsTruncated);
Assert.True(header.IsRecursionAvailable);
Assert.True(header.IsRecursionDesired);
Assert.Equal(DnsResponseCode.NXDomain, header.ResponseCode);
Assert.Equal((ushort)1, header.QuestionCount);
Assert.Equal((ushort)2, header.AnswerCount);
Assert.Equal((ushort)3, header.AuthorityCount);
Assert.Equal((ushort)4, header.AdditionalCount);
}

/// <summary>
/// Ensures header parsing fails for truncated data.
/// </summary>
[Fact]
public void TryParseHeader_ShouldReturnFalse_WhenDataTooShort() {
byte[] data = new byte[DnsHeaderLength - 1];

bool ok = DnsWireMessageParser.TryParseHeader(data, out var header);

Assert.False(ok);
Assert.Equal(default, header);
}

/// <summary>
/// Ensures EDNS parsing finds the OPT record and returns the UDP payload size.
/// </summary>
[Fact]
public void TryParseEdns_ShouldReturnUdpPayloadSize_WhenOptPresent() {
const ushort udpPayloadSize = 1232;
var opts = new DnsMessageOptions(EnableEdns: true, UdpBufferSize: udpPayloadSize);
var message = new DnsMessage("example.com", DnsRecordType.A, opts);

byte[] data = message.SerializeDnsWireFormat();

bool ok = DnsWireMessageParser.TryParseEdns(data, out var edns);

Assert.True(ok);
Assert.True(edns.Supported);
Assert.Equal(udpPayloadSize, edns.UdpPayloadSize);
}

/// <summary>
/// Ensures EDNS parsing returns a non-supported result when no OPT record is present.
/// </summary>
[Fact]
public void TryParseEdns_ShouldReturnUnsupported_WhenNoOptPresent() {
var opts = new DnsMessageOptions(EnableEdns: false);
var message = new DnsMessage("example.com", DnsRecordType.A, opts);

byte[] data = message.SerializeDnsWireFormat();

bool ok = DnsWireMessageParser.TryParseEdns(data, out var edns);

Assert.True(ok);
Assert.False(edns.Supported);
Assert.Equal(0, edns.UdpPayloadSize);
}

/// <summary>
/// Ensures EDNS parsing fails when the DNS message header is truncated.
/// </summary>
[Fact]
public void TryParseEdns_ShouldReturnFalse_WhenDataTooShort() {
byte[] data = new byte[DnsHeaderLength - 1];

bool ok = DnsWireMessageParser.TryParseEdns(data, out _);

Assert.False(ok);
}

/// <summary>
/// Ensures EDNS parsing fails when questions are declared but not present.
/// </summary>
[Fact]
public void TryParseEdns_ShouldReturnFalse_WhenQuestionIsTruncated() {
byte[] data = CreateHeader(flags: 0, qd: 1, an: 0, ns: 0, ar: 0);

bool ok = DnsWireMessageParser.TryParseEdns(data, out _);

Assert.False(ok);
}

/// <summary>
/// Ensures EDNS parsing fails for invalid label lengths (&gt; 63).
/// </summary>
[Fact]
public void TryParseEdns_ShouldReturnFalse_WhenQuestionLabelTooLong() {
byte[] header = CreateHeader(flags: 0, qd: 1, an: 0, ns: 0, ar: 0);
byte[] data = new byte[header.Length + 1];
Array.Copy(header, 0, data, 0, header.Length);
data[DnsHeaderLength] = 64;

bool ok = DnsWireMessageParser.TryParseEdns(data, out _);

Assert.False(ok);
}

/// <summary>
/// Ensures EDNS parsing can handle compressed names in the question section.
/// </summary>
[Fact]
public void TryParseEdns_ShouldHandleCompressedQuestionName() {
const ushort udpPayloadSize = 1232;
var bytes = new List<byte>();

byte[] header = CreateHeader(flags: 0, qd: 2, an: 0, ns: 0, ar: 1);
bytes.AddRange(header);

int firstQuestionNameOffset = bytes.Count;
AppendQuestion(bytes, "example.com", DnsRecordType.A, qclass: 1);

AppendCompressedQuestion(bytes, pointerOffset: (ushort)firstQuestionNameOffset, DnsRecordType.A, qclass: 1);

AppendOptRecord(bytes, udpPayloadSize);

bool ok = DnsWireMessageParser.TryParseEdns(bytes.ToArray(), out var edns);

Assert.True(ok);
Assert.True(edns.Supported);
Assert.Equal(udpPayloadSize, edns.UdpPayloadSize);
}

/// <summary>
/// Ensures EDNS parsing fails when a compression pointer is truncated.
/// </summary>
[Fact]
public void TryParseEdns_ShouldReturnFalse_WhenCompressionPointerTruncated() {
byte[] header = CreateHeader(flags: 0, qd: 1, an: 0, ns: 0, ar: 0);
byte[] data = new byte[header.Length + 1];
Array.Copy(header, 0, data, 0, header.Length);
data[DnsHeaderLength] = 0xC0;

bool ok = DnsWireMessageParser.TryParseEdns(data, out _);

Assert.False(ok);
}

private static byte[] CreateHeader(ushort flags, ushort qd, ushort an, ushort ns, ushort ar) {
var data = new byte[DnsHeaderLength];
WriteUInt16At(data, 0, 1); // ID
WriteUInt16At(data, 2, flags);
WriteUInt16At(data, 4, qd);
WriteUInt16At(data, 6, an);
WriteUInt16At(data, 8, ns);
WriteUInt16At(data, 10, ar);
return data;
}

private static void WriteUInt16At(byte[] buffer, int offset, ushort value) {
buffer[offset] = (byte)(value >> 8);
buffer[offset + 1] = (byte)(value & 0xFF);
}

private static void AppendQuestion(List<byte> message, string name, DnsRecordType qtype, ushort qclass) {
if (message == null) {
throw new ArgumentNullException(nameof(message));
}
if (string.IsNullOrWhiteSpace(name)) {
throw new ArgumentException("Name must not be empty.", nameof(name));
}

foreach (var label in name.TrimEnd('.').Split('.')) {
message.Add((byte)label.Length);
foreach (char c in label) {
message.Add((byte)c);
}
}
message.Add(0x00);

message.Add((byte)(((ushort)qtype >> 8) & 0xFF));
message.Add((byte)((ushort)qtype & 0xFF));
message.Add((byte)(qclass >> 8));
message.Add((byte)(qclass & 0xFF));
}

private static void AppendCompressedQuestion(List<byte> message, ushort pointerOffset, DnsRecordType qtype, ushort qclass) {
// Compression pointer: 11xx xxxx xxxx xxxx (14-bit offset).
byte first = (byte)(0xC0 | ((pointerOffset >> 8) & 0x3F));
byte second = (byte)(pointerOffset & 0xFF);

message.Add(first);
message.Add(second);

message.Add((byte)(((ushort)qtype >> 8) & 0xFF));
message.Add((byte)((ushort)qtype & 0xFF));
message.Add((byte)(qclass >> 8));
message.Add((byte)(qclass & 0xFF));
}

private static void AppendOptRecord(List<byte> message, ushort udpPayloadSize) {
message.Add(0x00); // root name
message.Add(0x00);
message.Add(0x29); // TYPE OPT
message.Add((byte)(udpPayloadSize >> 8));
message.Add((byte)(udpPayloadSize & 0xFF));
message.Add(0x00);
message.Add(0x00);
message.Add(0x00);
message.Add(0x00); // TTL
message.Add(0x00);
message.Add(0x00); // RDLEN
}
}
}
3 changes: 2 additions & 1 deletion DnsClientX/DnsMessageOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ public readonly record struct DnsMessageOptions(
EdnsClientSubnetOption? Subnet = null,
bool CheckingDisabled = false,
AsymmetricAlgorithm? SigningKey = null,
IEnumerable<EdnsOption>? Options = null);
IEnumerable<EdnsOption>? Options = null,
bool RecursionDesired = true);
8 changes: 6 additions & 2 deletions DnsClientX/ProtocolDnsWire/DnsMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace DnsClientX {
public class DnsMessage {
private readonly string _name;
private readonly DnsRecordType _type;
private readonly bool _recursionDesired;
private readonly bool _requestDnsSec;
private readonly bool _enableEdns;
private readonly int _udpBufferSize;
Expand Down Expand Up @@ -58,6 +59,7 @@ public DnsMessage(string name, DnsRecordType type, bool requestDnsSec, bool enab
public DnsMessage(string name, DnsRecordType type, DnsMessageOptions options) {
_name = name;
_type = type;
_recursionDesired = options.RecursionDesired;
_requestDnsSec = options.RequestDnsSec;
_ednsOptions = options.Options?.ToArray() ?? Array.Empty<EdnsOption>();
_enableEdns = options.EnableEdns || options.RequestDnsSec || options.Subnet != null || options.CheckingDisabled || _ednsOptions.Length > 0;
Expand Down Expand Up @@ -88,7 +90,8 @@ public string ToBase64Url() {
//stream.Write(buffer.ToArray(), 0, buffer.Length);

// Write the flags
BinaryPrimitives.WriteUInt16BigEndian(buffer, 0x0100);
ushort headerFlags = _recursionDesired ? (ushort)0x0100 : (ushort)0x0000;
BinaryPrimitives.WriteUInt16BigEndian(buffer, headerFlags);
stream.Write(buffer.ToArray(), 0, buffer.Length);

// Write the flags
Expand Down Expand Up @@ -216,7 +219,8 @@ public byte[] SerializeDnsWireFormat() {
ms.Write(bytes, 0, bytes.Length);

// Flags
bytes = BitConverter.GetBytes(IPAddress.HostToNetworkOrder((short)0x0100)); // Standard query
short headerFlags = _recursionDesired ? (short)0x0100 : (short)0x0000; // Standard query (RD optional)
bytes = BitConverter.GetBytes(IPAddress.HostToNetworkOrder(headerFlags));
ms.Write(bytes, 0, bytes.Length);

// Questions
Expand Down
Loading
Loading