diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs index 5bedc17d253f85..62a389ffe4ba5b 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/Dns.cs @@ -8,6 +8,7 @@ using System.Threading; using System.Threading.Tasks; using System.Runtime.Versioning; +using System.Diagnostics.CodeAnalysis; namespace System.Net { @@ -386,10 +387,45 @@ private static IPHostEntry GetHostEntryCore(string hostName, AddressFamily addre private static IPAddress[] GetHostAddressesCore(string hostName, AddressFamily addressFamily, NameResolutionActivity? activityOrDefault = default) => (IPAddress[])GetHostEntryOrAddressesCore(hostName, justAddresses: true, addressFamily, activityOrDefault); + private static bool ValidateAddressFamily(ref AddressFamily addressFamily, string hostName, bool justAddresses, [NotNullWhen(false)] out object? resultOnFailure) + { + if (!SocketProtocolSupportPal.OSSupportsIPv6) + { + if (addressFamily == AddressFamily.InterNetworkV6) + { + // The caller requested IPv6, but the OS doesn't support it; return an empty result. + IPAddress[] addresses = Array.Empty(); + resultOnFailure = justAddresses ? (object) + addresses : + new IPHostEntry + { + AddressList = addresses, + HostName = hostName, + Aliases = Array.Empty() + }; + return false; + } + else if (addressFamily == AddressFamily.Unspecified) + { + // Narrow the query to IPv4. + addressFamily = AddressFamily.InterNetwork; + } + } + + resultOnFailure = null; + return true; + } + private static object GetHostEntryOrAddressesCore(string hostName, bool justAddresses, AddressFamily addressFamily, NameResolutionActivity? activityOrDefault = default) { ValidateHostName(hostName); + if (!ValidateAddressFamily(ref addressFamily, hostName, justAddresses, out object? resultOnFailure)) + { + Debug.Assert(!activityOrDefault.HasValue); + return resultOnFailure; + } + // NameResolutionActivity may have already been set if we're being called from RunAsync. NameResolutionActivity activity = activityOrDefault ?? NameResolutionTelemetry.Log.BeforeResolution(hostName); @@ -463,6 +499,11 @@ private static object GetHostEntryOrAddressesCore(IPAddress address, bool justAd NameResolutionTelemetry.Log.AfterResolution(address, activity, answer: name); + if (!ValidateAddressFamily(ref addressFamily, name, justAddresses, out object? resultOnFailure)) + { + return resultOnFailure; + } + // Do the forward lookup to get the IPs for that host name activity = NameResolutionTelemetry.Log.BeforeResolution(name); @@ -518,6 +559,13 @@ private static Task GetHostEntryOrAddressesCoreAsync(string hostName, bool justR Task.FromCanceled(cancellationToken); } + if (!ValidateAddressFamily(ref family, hostName, justAddresses, out object? resultOnFailure)) + { + return justAddresses ? (Task) + Task.FromResult((IPAddress[])resultOnFailure) : + Task.FromResult((IPHostEntry)resultOnFailure); + } + object asyncState; // See if it's an IP Address. diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/ActivityTest.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/ActivityTest.cs index 13dfeb908942c5..c9c0e1b8d997b3 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/ActivityTest.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/ActivityTest.cs @@ -114,7 +114,7 @@ public static async Task ForwardLookup_InvalidHostName_ActivityRecorded(bool cre { const string InvalidHostName = $"invalid...example.com...{nameof(ForwardLookup_InvalidHostName_ActivityRecorded)}"; - await RemoteExecutor.Invoke(async (createParentActivity) => + await RemoteExecutor.Invoke(static async (createParentActivity) => { using var recorder = new ActivityRecorder(ActivitySourceName, ActivityName) { @@ -151,6 +151,26 @@ void Verify(int timesLookupRecorded) }, createParentActivity.ToString()).DisposeAsync(); } + [ConditionalFact(typeof(GetHostEntryTest), nameof(GetHostEntryTest.GetHostEntry_DisableIPv6_Condition))] + public static void ForwardLookup_DisableIPv6_AddressFamilyInterNetworkV6_ActivitiesAreFinished() + { + RemoteExecutor.Invoke(static async () => + { + const string ValidHostName = "localhost"; + AppContext.SetSwitch("System.Net.DisableIPv6", true); + using var recorder = new ActivityRecorder(ActivitySourceName, ActivityName); + + await Dns.GetHostEntryAsync(ValidHostName); + await Dns.GetHostAddressesAsync(ValidHostName); + Dns.GetHostEntry(ValidHostName); + Dns.GetHostAddresses(ValidHostName); + Dns.EndGetHostEntry(Dns.BeginGetHostEntry(ValidHostName, null, null)); + Dns.EndGetHostAddresses(Dns.BeginGetHostAddresses(ValidHostName, null, null)); + + Assert.Equal(recorder.Started, recorder.Stopped); + }).Dispose(); + } + static void VerifyForwardActivityInfo(Activity activity, string question) { Assert.Equal(ActivityKind.Internal, activity.Kind); diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs index d1073db5d3986a..99421451491714 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs @@ -6,7 +6,7 @@ using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; - +using Microsoft.DotNet.RemoteExecutor; using Xunit; namespace System.Net.NameResolution.Tests @@ -171,6 +171,39 @@ public async Task DnsGetHostAddresses_PreCancelledToken_Throws() OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => Dns.GetHostAddressesAsync(TestSettings.LocalHost, cts.Token)); Assert.Equal(cts.Token, oce.CancellationToken); } + + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(false)] + [InlineData(true)] + public void GetHostAddresses_DisableIPv6_ExcludesIPv6Addresses(bool useAsyncOuter) + { + RemoteExecutor.Invoke(RunTest, useAsyncOuter.ToString()).Dispose(); + + static async Task RunTest(string useAsync) + { + AppContext.SetSwitch("System.Net.DisableIPv6", true); + IPAddress[] addresses = + bool.Parse(useAsync) ? await Dns.GetHostAddressesAsync(TestSettings.LocalHost) : + Dns.GetHostAddresses(TestSettings.LocalHost); + Assert.All(addresses, address => Assert.Equal(AddressFamily.InterNetwork, address.AddressFamily)); + } + } + + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [InlineData(false)] + [InlineData(true)] + public void GetHostAddresses_DisableIPv6_AddressFamilyInterNetworkV6_ReturnsEmpty(bool useAsyncOuter) + { + RemoteExecutor.Invoke(RunTest, useAsyncOuter.ToString()).Dispose(); + static async Task RunTest(string useAsync) + { + AppContext.SetSwitch("System.Net.DisableIPv6", true); + IPAddress[] addresses = + bool.Parse(useAsync) ? await Dns.GetHostAddressesAsync(TestSettings.LocalHost, AddressFamily.InterNetworkV6) : + Dns.GetHostAddresses(TestSettings.LocalHost, AddressFamily.InterNetworkV6); + Assert.Empty(addresses); + } + } } // Cancellation tests are sequential to reduce the chance of timing issues. diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs index c4988cef816725..913a0a3412dbf2 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; @@ -108,38 +109,41 @@ private static async Task TestGetHostEntryAsync(Func> getHostE public static bool GetHostEntry_DisableIPv6_Condition = GetHostEntryWorks && RemoteExecutor.IsSupported; [ConditionalTheory(nameof(GetHostEntry_DisableIPv6_Condition))] - [InlineData("")] - [InlineData(TestSettings.LocalHost)] - public void Dns_GetHostEntry_DisableIPv6_ExcludesIPv6Addresses(string hostnameOuter) + [InlineData("", false)] + [InlineData("", true)] + [InlineData(TestSettings.LocalHost, false)] + [InlineData(TestSettings.LocalHost, true)] + public void GetHostEntry_DisableIPv6_ExcludesIPv6Addresses(string hostnameOuter, bool useAsyncOuter) { - RemoteExecutor.Invoke(RunTest, hostnameOuter).Dispose(); + string expectedHostName = Dns.GetHostEntry(hostnameOuter).HostName; + RemoteExecutor.Invoke(RunTest, hostnameOuter, expectedHostName, useAsyncOuter.ToString()).Dispose(); - static void RunTest(string hostnameInner) + static async Task RunTest(string hostnameInner, string expectedHostName, string useAsync) { AppContext.SetSwitch("System.Net.DisableIPv6", true); - IPHostEntry entry = Dns.GetHostEntry(hostnameInner); - foreach (IPAddress address in entry.AddressList) - { - Assert.NotEqual(AddressFamily.InterNetworkV6, address.AddressFamily); - } + + IPHostEntry entry = bool.Parse(useAsync) ? + await Dns.GetHostEntryAsync(hostnameInner) : + Dns.GetHostEntry(hostnameInner); + + Assert.Equal(entry.HostName, expectedHostName); + Assert.All(entry.AddressList, address => Assert.Equal(AddressFamily.InterNetwork, address.AddressFamily)); } } [ConditionalTheory(nameof(GetHostEntry_DisableIPv6_Condition))] - [InlineData("")] - [InlineData(TestSettings.LocalHost)] - public void Dns_GetHostEntryAsync_DisableIPv6_ExcludesIPv6Addresses(string hostnameOuter) + [InlineData(false)] + [InlineData(true)] + public void GetHostEntry_DisableIPv6_AddressFamilyInterNetworkV6_ReturnsEmpty(bool useAsyncOuter) { - RemoteExecutor.Invoke(RunTest, hostnameOuter).Dispose(); - - static async Task RunTest(string hostnameInner) + RemoteExecutor.Invoke(RunTest, useAsyncOuter.ToString()).Dispose(); + static async Task RunTest(string useAsync) { AppContext.SetSwitch("System.Net.DisableIPv6", true); - IPHostEntry entry = await Dns.GetHostEntryAsync(hostnameInner); - foreach (IPAddress address in entry.AddressList) - { - Assert.NotEqual(AddressFamily.InterNetworkV6, address.AddressFamily); - } + IPHostEntry entry = bool.Parse(useAsync) ? + await Dns.GetHostEntryAsync(TestSettings.LocalHost, AddressFamily.InterNetworkV6) : + Dns.GetHostEntry(TestSettings.LocalHost, AddressFamily.InterNetworkV6); + Assert.Empty(entry.AddressList); } }