diff --git a/DnsClientX.Tests/ResolveFromRootTests.cs b/DnsClientX.Tests/ResolveFromRootTests.cs
index 340dd454..4f13e641 100644
--- a/DnsClientX.Tests/ResolveFromRootTests.cs
+++ b/DnsClientX.Tests/ResolveFromRootTests.cs
@@ -1,3 +1,5 @@
+using System;
+using System.Collections.Generic;
using System.Threading.Tasks;
using Xunit;
@@ -17,5 +19,70 @@ public async Task ShouldResolveARecordFromRoot() {
Assert.Equal(DnsRecordType.A, ans.Type);
}
}
+
+ [Fact]
+ ///
+ /// Ensures root lookups created in parallel keep their clients alive until completion.
+ ///
+ public async Task QueryDns_MultipleRootLookups_DoesNotDisposeClientsEarly() {
+ var createdClients = new List();
+ var completions = new List>();
+ Func originalFactory = ClientX.RootClientFactory;
+ var originalResolver = ClientX.RootResolveOverride;
+
+ try {
+ ClientX.RootClientFactory = () => {
+ var client = new TrackingClientX();
+ createdClients.Add(client);
+ return client;
+ };
+
+ ClientX.RootResolveOverride = (client, name, recordType, cancellationToken) => {
+ var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ completions.Add(tcs);
+ return tcs.Task;
+ };
+
+ Task queryTask = ClientX.QueryDns(new[] { "example.com", "example.net" }, DnsRecordType.A, DnsEndpoint.RootServer);
+
+ await Task.Delay(10);
+ Assert.Equal(2, completions.Count);
+ Assert.All(createdClients, client => Assert.False(client.IsDisposed));
+
+ foreach (var completion in completions) {
+ completion.SetResult(new DnsResponse {
+ Answers = new[] {
+ new DnsAnswer {
+ Name = "example.com",
+ Type = DnsRecordType.A,
+ TTL = 60,
+ DataRaw = "127.0.0.1"
+ }
+ }
+ });
+ }
+
+ var responses = await queryTask;
+ Assert.Equal(2, responses.Length);
+ Assert.All(createdClients, client => Assert.True(client.IsDisposed));
+ } finally {
+ ClientX.RootClientFactory = originalFactory;
+ ClientX.RootResolveOverride = originalResolver;
+ }
+ }
+
+ private sealed class TrackingClientX : ClientX {
+ public bool IsDisposed { get; private set; }
+
+ protected override void Dispose(bool disposing) {
+ base.Dispose(disposing);
+ IsDisposed = true;
+ }
+
+ protected override async ValueTask DisposeAsyncCore() {
+ await base.DisposeAsyncCore().ConfigureAwait(false);
+ IsDisposed = true;
+ }
+ }
}
}
diff --git a/DnsClientX/AssemblyInfo.cs b/DnsClientX/AssemblyInfo.cs
index 84984ce6..09677f5a 100644
--- a/DnsClientX/AssemblyInfo.cs
+++ b/DnsClientX/AssemblyInfo.cs
@@ -1,2 +1,3 @@
using System.Runtime.CompilerServices;
+
[assembly: InternalsVisibleTo("DnsClientX.Tests")]
diff --git a/DnsClientX/DnsClientX.QueryDns.cs b/DnsClientX/DnsClientX.QueryDns.cs
index bdddc10c..77fae1cb 100644
--- a/DnsClientX/DnsClientX.QueryDns.cs
+++ b/DnsClientX/DnsClientX.QueryDns.cs
@@ -13,6 +13,10 @@ namespace DnsClientX {
/// Provides synchronous and asynchronous methods for performing DNS lookups.
///
public partial class ClientX {
+ internal static Func RootClientFactory { get; set; } = () => new ClientX();
+
+ internal static Func>? RootResolveOverride { get; set; }
+
///
/// Sends a DNS query for a specific record type to a DNS server.
/// This method allows you to specify the DNS endpoint from a predefined list of endpoints.
@@ -91,14 +95,27 @@ public static DnsResponse QueryDnsSync(string name, DnsRecordType recordType, Dn
/// A task that represents the asynchronous operation. The task result contains the DNS response.
public static async Task QueryDns(string[] name, DnsRecordType recordType, DnsEndpoint dnsEndpoint = DnsEndpoint.System, DnsSelectionStrategy dnsSelectionStrategy = DnsSelectionStrategy.First, int timeOutMilliseconds = Configuration.DefaultTimeout, bool retryOnTransient = true, int maxRetries = 3, int retryDelayMs = 200, bool requestDnsSec = false, bool validateDnsSec = false, bool typedRecords = false, bool parseTypedTxtRecords = false, CancellationToken cancellationToken = default) {
if (dnsEndpoint == DnsEndpoint.RootServer) {
- var tasks = name.Select(n => {
- using var client = new ClientX();
- if (cancellationToken.IsCancellationRequested) {
- return Task.FromCanceled(cancellationToken);
+ var clients = new List(name.Length);
+ try {
+ var tasks = name.Select(n => {
+ var client = RootClientFactory();
+ clients.Add(client);
+ if (cancellationToken.IsCancellationRequested) {
+ return Task.FromCanceled(cancellationToken);
+ }
+
+ var resolver = RootResolveOverride;
+ return resolver != null
+ ? resolver(client, n, recordType, cancellationToken)
+ : client.ResolveFromRoot(n, recordType, cancellationToken: cancellationToken);
+ }).ToArray();
+
+ return await Task.WhenAll(tasks).ConfigureAwait(false);
+ } finally {
+ foreach (var client in clients) {
+ client.Dispose();
}
- return client.ResolveFromRoot(n, recordType, cancellationToken: cancellationToken);
- });
- return await Task.WhenAll(tasks).ConfigureAwait(false);
+ }
} else {
using var client = new ClientX(endpoint: dnsEndpoint, dnsSelectionStrategy);
client.EndpointConfiguration.TimeOut = timeOutMilliseconds;