Skip to content

Add MultiCredentialSecurityTokenManager to handle CoreWCF service certificates #5802

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ public static void DuplexCallback_Throws_FaultException_ReturnsFaultedTask()
}

[WcfFact]
[Condition(nameof(Skip_CoreWCFService_FailedTest))]
[OuterLoop]
// Verify product throws MessageSecurityException when the Dns identity from the server does not match the expectation
public static void TCP_ServiceCertExpired_Throw_MessageSecurityException()
Expand Down Expand Up @@ -373,7 +372,6 @@ public static void TCP_ServiceCertExpired_Throw_MessageSecurityException()
}

[WcfFact]
[Condition(nameof(Skip_CoreWCFService_FailedTest))]
[OuterLoop]
// Verify product throws SecurityNegotiationException when the service cert is revoked
public static void TCP_ServiceCertRevoked_Throw_SecurityNegotiationException()
Expand Down Expand Up @@ -421,7 +419,6 @@ public static void TCP_ServiceCertRevoked_Throw_SecurityNegotiationException()
}

[WcfFact]
[Condition(nameof(Skip_CoreWCFService_FailedTest))]
[OuterLoop]
// Verify product throws SecurityNegotiationException when the service cert only has the ClientAuth usage
public static void TCP_ServiceCertInvalidEKU_Throw_SecurityNegotiationException()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#if NET
using CoreWCF.Description;
using CoreWCF.IdentityModel.Selectors;
using CoreWCF.Security;
using CoreWCF.Security.Tokens;

namespace WcfService
{
public class MultiCredentialSecurityTokenManager : ServiceCredentialsSecurityTokenManager
{
private readonly MultiCredentialServiceCredentials _parent;
private readonly Dictionary<string, ServiceCredentials> _map;

public MultiCredentialSecurityTokenManager(
MultiCredentialServiceCredentials parent,
Dictionary<string, ServiceCredentials> map)
: base(parent)
{
_parent = parent;
_map = map;
}

public override SecurityTokenProvider CreateSecurityTokenProvider(SecurityTokenRequirement tokenRequirement)
{
if (tokenRequirement is RecipientServiceModelSecurityTokenRequirement recipientRequirement)
{
var uri = recipientRequirement.ListenUri?.AbsolutePath;
if (!string.IsNullOrEmpty(uri) && _map.TryGetValue(uri, out var creds))
{
return creds.CreateSecurityTokenManager().CreateSecurityTokenProvider(tokenRequirement);
}
}
return _parent.CreateOriginalSecurityTokenManager().CreateSecurityTokenProvider(tokenRequirement);
}

public override SecurityTokenAuthenticator CreateSecurityTokenAuthenticator(SecurityTokenRequirement tokenRequirement, out SecurityTokenResolver outOfBandTokenResolver)
{
if (tokenRequirement is RecipientServiceModelSecurityTokenRequirement recipientRequirement)
{
var uri = recipientRequirement.ListenUri?.AbsolutePath;
if (!string.IsNullOrEmpty(uri) && _map.TryGetValue(uri, out var creds))
{
return creds.CreateSecurityTokenManager().CreateSecurityTokenAuthenticator(tokenRequirement, out outOfBandTokenResolver);
}
}

return _parent.CreateOriginalSecurityTokenManager().CreateSecurityTokenAuthenticator(tokenRequirement, out outOfBandTokenResolver);
}
}
}
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#if NET
using CoreWCF.Description;
using CoreWCF.IdentityModel.Selectors;

namespace WcfService
{
public class MultiCredentialServiceCredentials : ServiceCredentials
{
private readonly Dictionary<string, ServiceCredentials> _serviceCredentialsMap = new();

public void AddServiceCredentials(string path, ServiceCredentials credentials)
{
_serviceCredentialsMap[path] = credentials;
}

public IReadOnlyDictionary<string, ServiceCredentials> ServiceCredentialsMap => _serviceCredentialsMap;

public override SecurityTokenManager CreateSecurityTokenManager()
{
return new MultiCredentialSecurityTokenManager(this, _serviceCredentialsMap);
}

internal SecurityTokenManager CreateOriginalSecurityTokenManager()
{
return base.CreateSecurityTokenManager();
}

protected override ServiceCredentials CloneCore()
{
var clone = new MultiCredentialServiceCredentials();
foreach (var kvp in _serviceCredentialsMap)
{
clone.AddServiceCredentials(kvp.Key, kvp.Value.Clone());
}
return clone;
}
}
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public class ServiceHost
private ServiceHostBase _serviceHostBase = null;
private readonly Type _serviceType;
private readonly List<Endpoint> _endpoints = new List<Endpoint>();
private ServiceCredentials _localCredentials = null;

public ServiceHost(Type serviceType, params Uri[] baseAddresses)
{
Expand Down Expand Up @@ -50,7 +51,47 @@ public class Endpoint

public Type ServiceType => _serviceHostBase != null ? _serviceHostBase.Description.ServiceType : _serviceType;

public ServiceCredentials Credentials => _serviceHostBase != null ? _serviceHostBase.Credentials : new ServiceCredentials();
public ServiceCredentials Credentials
{
get
{
if (_localCredentials != null)
{
return _localCredentials;
}

_localCredentials = new ServiceCredentials();
var multiCreds = _serviceHostBase.Credentials as MultiCredentialServiceCredentials;
if (multiCreds == null)
{
throw new Exception("Credentials should have been initialized with MultiCredentialServiceCredentials");
}
var attributes = this.GetType().GetCustomAttributes(typeof(TestServiceDefinitionAttribute), false);
if (attributes != null)
{
foreach (var attribute in attributes)
{
var basePath = "/" + ((TestServiceDefinitionAttribute)attribute).BasePath;
if (!string.IsNullOrEmpty(basePath))
{
foreach (var endpoint in _endpoints)
{
var path = string.IsNullOrEmpty(endpoint.Address) ? basePath : basePath + "/" + endpoint.Address;
if (!multiCreds.ServiceCredentialsMap.TryGetValue(path, out var creds))
{
multiCreds.AddServiceCredentials(path, _localCredentials);
}
else
{
_localCredentials = creds;
}
}
}
}
}
return _localCredentials;
}
}

public ServiceDescription Description => _serviceHostBase != null ? _serviceHostBase.Description : new ServiceDescription();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ internal static async Task<IWebHost> StartHosts(bool useWebSocket)
{
bool success = true;
var serviceTestHostOptionsDict = new Dictionary<string, ServiceTestHostOptions>();
var multiCreds = new MultiCredentialServiceCredentials();

var webHostBuilder = new WebHostBuilder()
.ConfigureLogging((ILoggingBuilder logging) =>
Expand Down Expand Up @@ -256,7 +257,13 @@ internal static async Task<IWebHost> StartHosts(bool useWebSocket)

smb.HttpGetEnabled = true;
}


var creds = serviceHostBase.Description.Behaviors.Find<ServiceCredentials>();
if (creds != null)
{
serviceHostBase.Description.Behaviors.Remove(creds);
}
serviceHostBase.Description.Behaviors.Add(multiCreds);
serviceHost.ApplyConfig(serviceHostBase);
});
}
Expand All @@ -272,8 +279,7 @@ internal static async Task<IWebHost> StartHosts(bool useWebSocket)
Console.BackgroundColor = bg;
Console.ForegroundColor = fg;
}
}

}
});
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,9 @@ public static X509Certificate2 CertificateFromFriendlyName(StoreName name, Store
#endif
foreach (X509Certificate2 cert in foundCertificates)
{
// Search by serial number in Linux/MacOS
if (cert.FriendlyName == friendlyName || cert.SerialNumber == friendlyNameHash)
// Search by friendly name in Windows or by serial number in Linux/MacOS (which is the hash of the friendly name).
// Remove any leading zeros from the number string in certificate SerialNumber using TrimStart('0').
if (cert.FriendlyName == friendlyName || cert.SerialNumber.TrimStart('0') == friendlyNameHash)
{
return cert;
}
Expand Down
Loading