diff --git a/DnsClientX.Tests/ResolveFromRootTests.cs b/DnsClientX.Tests/ResolveFromRootTests.cs index 340dd454..4f13e641 100644 --- a/DnsClientX.Tests/ResolveFromRootTests.cs +++ b/DnsClientX.Tests/ResolveFromRootTests.cs @@ -1,3 +1,5 @@ +using System; +using System.Collections.Generic; using System.Threading.Tasks; using Xunit; @@ -17,5 +19,70 @@ public async Task ShouldResolveARecordFromRoot() { Assert.Equal(DnsRecordType.A, ans.Type); } } + + [Fact] + /// + /// Ensures root lookups created in parallel keep their clients alive until completion. + /// + public async Task QueryDns_MultipleRootLookups_DoesNotDisposeClientsEarly() { + var createdClients = new List(); + var completions = new List>(); + Func originalFactory = ClientX.RootClientFactory; + var originalResolver = ClientX.RootResolveOverride; + + try { + ClientX.RootClientFactory = () => { + var client = new TrackingClientX(); + createdClients.Add(client); + return client; + }; + + ClientX.RootResolveOverride = (client, name, recordType, cancellationToken) => { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + completions.Add(tcs); + return tcs.Task; + }; + + Task queryTask = ClientX.QueryDns(new[] { "example.com", "example.net" }, DnsRecordType.A, DnsEndpoint.RootServer); + + await Task.Delay(10); + Assert.Equal(2, completions.Count); + Assert.All(createdClients, client => Assert.False(client.IsDisposed)); + + foreach (var completion in completions) { + completion.SetResult(new DnsResponse { + Answers = new[] { + new DnsAnswer { + Name = "example.com", + Type = DnsRecordType.A, + TTL = 60, + DataRaw = "127.0.0.1" + } + } + }); + } + + var responses = await queryTask; + Assert.Equal(2, responses.Length); + Assert.All(createdClients, client => Assert.True(client.IsDisposed)); + } finally { + ClientX.RootClientFactory = originalFactory; + ClientX.RootResolveOverride = originalResolver; + } + } + + private sealed class TrackingClientX : ClientX { + public bool IsDisposed { get; private set; } + + protected override void Dispose(bool disposing) { + base.Dispose(disposing); + IsDisposed = true; + } + + protected override async ValueTask DisposeAsyncCore() { + await base.DisposeAsyncCore().ConfigureAwait(false); + IsDisposed = true; + } + } } } diff --git a/DnsClientX/AssemblyInfo.cs b/DnsClientX/AssemblyInfo.cs index 84984ce6..09677f5a 100644 --- a/DnsClientX/AssemblyInfo.cs +++ b/DnsClientX/AssemblyInfo.cs @@ -1,2 +1,3 @@ using System.Runtime.CompilerServices; + [assembly: InternalsVisibleTo("DnsClientX.Tests")] diff --git a/DnsClientX/DnsClientX.QueryDns.cs b/DnsClientX/DnsClientX.QueryDns.cs index bdddc10c..77fae1cb 100644 --- a/DnsClientX/DnsClientX.QueryDns.cs +++ b/DnsClientX/DnsClientX.QueryDns.cs @@ -13,6 +13,10 @@ namespace DnsClientX { /// Provides synchronous and asynchronous methods for performing DNS lookups. /// public partial class ClientX { + internal static Func RootClientFactory { get; set; } = () => new ClientX(); + + internal static Func>? RootResolveOverride { get; set; } + /// /// Sends a DNS query for a specific record type to a DNS server. /// This method allows you to specify the DNS endpoint from a predefined list of endpoints. @@ -91,14 +95,27 @@ public static DnsResponse QueryDnsSync(string name, DnsRecordType recordType, Dn /// A task that represents the asynchronous operation. The task result contains the DNS response. public static async Task QueryDns(string[] name, DnsRecordType recordType, DnsEndpoint dnsEndpoint = DnsEndpoint.System, DnsSelectionStrategy dnsSelectionStrategy = DnsSelectionStrategy.First, int timeOutMilliseconds = Configuration.DefaultTimeout, bool retryOnTransient = true, int maxRetries = 3, int retryDelayMs = 200, bool requestDnsSec = false, bool validateDnsSec = false, bool typedRecords = false, bool parseTypedTxtRecords = false, CancellationToken cancellationToken = default) { if (dnsEndpoint == DnsEndpoint.RootServer) { - var tasks = name.Select(n => { - using var client = new ClientX(); - if (cancellationToken.IsCancellationRequested) { - return Task.FromCanceled(cancellationToken); + var clients = new List(name.Length); + try { + var tasks = name.Select(n => { + var client = RootClientFactory(); + clients.Add(client); + if (cancellationToken.IsCancellationRequested) { + return Task.FromCanceled(cancellationToken); + } + + var resolver = RootResolveOverride; + return resolver != null + ? resolver(client, n, recordType, cancellationToken) + : client.ResolveFromRoot(n, recordType, cancellationToken: cancellationToken); + }).ToArray(); + + return await Task.WhenAll(tasks).ConfigureAwait(false); + } finally { + foreach (var client in clients) { + client.Dispose(); } - return client.ResolveFromRoot(n, recordType, cancellationToken: cancellationToken); - }); - return await Task.WhenAll(tasks).ConfigureAwait(false); + } } else { using var client = new ClientX(endpoint: dnsEndpoint, dnsSelectionStrategy); client.EndpointConfiguration.TimeOut = timeOutMilliseconds;