diff --git a/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs new file mode 100644 index 00000000..723d2e16 --- /dev/null +++ b/src/DistributedLock.Core/CompositeDistributedSynchronizationHandle.cs @@ -0,0 +1,489 @@ +using Medallion.Threading.Internal; + +namespace Medallion.Threading; + +internal sealed class CompositeDistributedSynchronizationHandle : IDistributedSynchronizationHandle +{ + private readonly IDistributedSynchronizationHandle[] _handles; + private readonly CancellationTokenSource? _linkedLostCts; + private bool _disposed; + + public CompositeDistributedSynchronizationHandle(IReadOnlyList handles) + { + ValidateHandles(handles); + this._handles = handles.ToArray(); + this._linkedLostCts = this.CreateLinkedCancellationTokenSource(); + } + + public CancellationToken HandleLostToken => this._linkedLostCts?.Token ?? CancellationToken.None; + + public void Dispose() + { + if (this._disposed) + { + return; + } + + this._disposed = true; + var errors = this.DisposeHandles(h => h.Dispose()); + this._linkedLostCts?.Dispose(); + ThrowAggregateExceptionIfNeeded(errors, "disposing"); + } + + public async ValueTask DisposeAsync() + { + if (this._disposed) + { + return; + } + + this._disposed = true; + var errors = await this.DisposeHandlesAsync(h => h.DisposeAsync()).ConfigureAwait(false); + this._linkedLostCts?.Dispose(); + ThrowAggregateExceptionIfNeeded(errors, "asynchronously disposing"); + } + + public static async ValueTask TryAcquireAllAsync( + TProvider provider, + Func> acquireFunc, + IReadOnlyList names, + TimeSpan timeout = default, + CancellationToken cancellationToken = default) + { + ValidateAcquireParameters(provider, acquireFunc, names); + + var timeoutTracker = new TimeoutTracker(timeout); + var handles = new List(names.Count); + + try + { + foreach (var name in names) + { + var handle = await acquireFunc(provider, name, timeoutTracker.Remaining, cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + return null; + } + + handles.Add(handle); + + if (timeoutTracker.IsExpired) + { + return null; + } + } + + var result = new CompositeDistributedSynchronizationHandle(handles); + handles.Clear(); + return result; + } + finally + { + await DisposeHandlesAsync(handles).ConfigureAwait(false); + } + } + + + public static async ValueTask AcquireAllAsync( + TProvider provider, + Func> acquireFunc, + IReadOnlyList names, + TimeSpan? timeout = null, + CancellationToken cancellationToken = default) + { + var effectiveTimeout = timeout ?? Timeout.InfiniteTimeSpan; + var handle = await TryAcquireAllAsync( + provider, + WrapAcquireFunc(acquireFunc), + names, + effectiveTimeout, + cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + throw new TimeoutException($"Timed out after {effectiveTimeout} while acquiring all locks."); + } + + return handle; + } + + public static IDistributedSynchronizationHandle? TryAcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + TimeSpan timeout = default, + CancellationToken cancellationToken = default) => + SyncViaAsync.Run( + state => TryAcquireAllAsync( + state.provider, + WrapSyncAcquireFunc(state.acquireFunc), + state.names, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, timeout, cancellationToken) + ); + + public static IDistributedSynchronizationHandle AcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + TimeSpan? timeout = null, + CancellationToken cancellationToken = default) => + SyncViaAsync.Run( + state => AcquireAllAsync( + state.provider, + WrapSyncAcquireFuncForRequired(state.acquireFunc), + state.names, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, timeout, cancellationToken) + ); + + public static async ValueTask TryAcquireAllAsync( + TProvider provider, + Func> + acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan timeout = default, + CancellationToken cancellationToken = default) + { + ValidateAcquireParameters(provider, acquireFunc, names); + + var timeoutTracker = new TimeoutTracker(timeout); + var handles = new List(names.Count); + + try + { + foreach (var name in names) + { + var handle = await acquireFunc(provider, name, maxCount, timeoutTracker.Remaining, cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + return null; + } + + handles.Add(handle); + + if (timeoutTracker.IsExpired) + { + return null; + } + } + + var result = new CompositeDistributedSynchronizationHandle(handles); + handles.Clear(); + return result; + } + finally + { + await DisposeHandlesAsync(handles).ConfigureAwait(false); + } + } + + + public static async ValueTask AcquireAllAsync( + TProvider provider, + Func> + acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan? timeout = null, + CancellationToken cancellationToken = default) + { + var effectiveTimeout = timeout ?? Timeout.InfiniteTimeSpan; + var handle = await TryAcquireAllAsync( + provider, + WrapAcquireFunc(acquireFunc), + names, + maxCount, + effectiveTimeout, + cancellationToken) + .ConfigureAwait(false); + + if (handle is null) + { + throw new TimeoutException($"Timed out after {effectiveTimeout} while acquiring all locks."); + } + + return handle; + } + + public static IDistributedSynchronizationHandle? TryAcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan timeout = default, + CancellationToken cancellationToken = default) => + SyncViaAsync.Run( + state => TryAcquireAllAsync( + state.provider, + WrapSyncAcquireFunc(state.acquireFunc), + state.names, + state.maxCount, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, maxCount, timeout, cancellationToken) + ); + + public static IDistributedSynchronizationHandle AcquireAll( + TProvider provider, + Func acquireFunc, + IReadOnlyList names, + int maxCount, + TimeSpan? timeout = null, + CancellationToken cancellationToken = default) => + SyncViaAsync.Run( + state => AcquireAllAsync( + state.provider, + WrapSyncAcquireFuncForRequired(state.acquireFunc), + state.names, + state.maxCount, + state.timeout, + state.cancellationToken), + (provider, acquireFunc, names, maxCount, timeout, cancellationToken) + ); + + private static void ValidateHandles(IReadOnlyList handles) + { + if (handles is null) + { + throw new ArgumentNullException(nameof(handles)); + } + + if (handles.Count == 0) + { + throw new ArgumentException("At least one handle is required", nameof(handles)); + } + + for (var i = 0; i < handles.Count; ++i) + { + if (handles[i] is null) + { + throw new ArgumentException( + $"Handles must not contain null elements; found null at index {i}", + nameof(handles) + ); + } + } + } + + private CancellationTokenSource? CreateLinkedCancellationTokenSource() + { + var cancellableTokens = this._handles + .Select(h => h.HandleLostToken) + .Where(t => t.CanBeCanceled) + .ToArray(); + + return cancellableTokens.Length > 0 + ? CancellationTokenSource.CreateLinkedTokenSource(cancellableTokens) + : null; + } + + private List? DisposeHandles(Action disposeAction) + { + List? errors = null; + + foreach (var handle in this._handles) + { + try + { + disposeAction(handle); + } + catch (Exception ex) + { + (errors ??= []).Add(ex); + } + } + + return errors; + } + + private async ValueTask?> DisposeHandlesAsync( + Func disposeAction) + { + List? errors = null; + + foreach (var handle in this._handles) + { + try + { + await disposeAction(handle).ConfigureAwait(false); + } + catch (Exception ex) + { + (errors ??= []).Add(ex); + } + } + + return errors; + } + + private static void ThrowAggregateExceptionIfNeeded(List? errors, string operation) + { + if (errors is not null && errors.Count > 0) + { + throw new AggregateException( + $"One or more errors occurred while {operation} a composite distributed handle.", errors); + } + } + + private static void ValidateAcquireParameters( + TProvider provider, + Func> acquireFunc, + IReadOnlyList names) + { + if (provider is null) + { + throw new ArgumentNullException(nameof(provider)); + } + + if (acquireFunc is null) + { + throw new ArgumentNullException(nameof(acquireFunc)); + } + + if (names is null) + { + throw new ArgumentNullException(nameof(names)); + } + + if (names.Count == 0) + { + throw new ArgumentException("At least one lock name is required.", nameof(names)); + } + + for (var i = 0; i < names.Count; ++i) + { + if (names[i] is null) + { + throw new ArgumentException( + $"Names must not contain null elements; found null at index {i}", + nameof(names) + ); + } + } + } + + private static void ValidateAcquireParameters( + TProvider provider, + Func> + acquireFunc, + IReadOnlyList names) + { + if (provider is null) + { + throw new ArgumentNullException(nameof(provider)); + } + + if (acquireFunc is null) + { + throw new ArgumentNullException(nameof(acquireFunc)); + } + + if (names is null) + { + throw new ArgumentNullException(nameof(names)); + } + + if (names.Count == 0) + { + throw new ArgumentException("At least one lock name is required.", nameof(names)); + } + + for (var i = 0; i < names.Count; ++i) + { + if (names[i] is null) + { + throw new ArgumentException( + $"Names must not contain null elements; found null at index {i}", + nameof(names) + ); + } + } + } + + private static async ValueTask DisposeHandlesAsync(List handles) + { + foreach (var handle in handles) + { + try + { + await handle.DisposeAsync().ConfigureAwait(false); + } + catch + { + // Suppress exceptions during cleanup + } + } + } + + private static Func> + WrapAcquireFunc( + Func> + acquireFunc) => + async (p, n, t, c) => await acquireFunc(p, n, t, c).ConfigureAwait(false); + + private static Func> + WrapSyncAcquireFunc( + Func acquireFunc) => + (p, n, t, c) => new ValueTask(acquireFunc(p, n, t, c)); + + private static Func> + WrapSyncAcquireFuncForRequired( + Func acquireFunc) => + (p, n, t, c) => + { + var handle = acquireFunc(p, n, t, c); + return handle is not null + ? new ValueTask(handle) + : throw new TimeoutException($"Failed to acquire lock for '{n}'"); + }; + + + private static Func> + WrapAcquireFunc( + Func> + acquireFunc) => + async (p, n, mc, t, c) => await acquireFunc(p, n, mc, t, c).ConfigureAwait(false); + + private static Func> + WrapSyncAcquireFunc( + Func + acquireFunc) => + (p, n, mc, t, c) => new ValueTask(acquireFunc(p, n, mc, t, c)); + + private static Func> + WrapSyncAcquireFuncForRequired( + Func + acquireFunc) => + (p, n, mc, t, c) => + { + var handle = acquireFunc(p, n, mc, t, c); + return handle is not null + ? new ValueTask(handle) + : throw new TimeoutException($"Failed to acquire lock for '{n}'"); + }; + + private sealed class TimeoutTracker(TimeSpan timeout) + { + private readonly System.Diagnostics.Stopwatch? _stopwatch = timeout == Timeout.InfiniteTimeSpan + ? null + : System.Diagnostics.Stopwatch.StartNew(); + + public TimeSpan Remaining => this._stopwatch is null + ? Timeout.InfiniteTimeSpan + : timeout - this._stopwatch.Elapsed; + + public bool IsExpired => this._stopwatch is not null && this._stopwatch.Elapsed >= timeout; + } +} \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedLockProviderExtensions.cs index ea4ad6f5..bc98c6bb 100644 --- a/src/DistributedLock.Core/DistributedLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedLockProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedLockProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -34,4 +36,50 @@ public static IDistributedSynchronizationHandle AcquireLock(this IDistributedLoc /// public static ValueTask AcquireLockAsync(this IDistributedLockProvider provider, string name, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateLock(name).AcquireAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllLocks(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, t, c) => p.TryAcquireLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllLocks(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, t, c) => p.AcquireLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllLocksAsync(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, t, c) => p.TryAcquireLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllLocksAsync(this IDistributedLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, t, c) => p.AcquireLockAsync(n, t, c), + names, timeout, cancellationToken); + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs index 5ac66fac..da5f407c 100644 --- a/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedReaderWriterLockProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedReaderWriterLockProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -62,4 +64,90 @@ public static IDistributedSynchronizationHandle AcquireWriteLock(this IDistribut /// public static ValueTask AcquireWriteLockAsync(this IDistributedReaderWriterLockProvider provider, string name, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateReaderWriterLock(name).AcquireWriteLockAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllReadLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, t, c) => p.TryAcquireReadLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllReadLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, t, c) => p.AcquireReadLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllReadLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, t, c) => p.TryAcquireReadLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllReadLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, t, c) => p.AcquireReadLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllWriteLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, t, c) => p.TryAcquireWriteLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllWriteLocks(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, t, c) => p.AcquireWriteLock(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllWriteLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, t, c) => p.TryAcquireWriteLockAsync(n, t, c), + names, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllWriteLocksAsync(this IDistributedReaderWriterLockProvider provider, IReadOnlyList names, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, t, c) => p.AcquireWriteLockAsync(n, t, c), + names, timeout, cancellationToken); + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs b/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs index b808b004..3ed29ebe 100644 --- a/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedSemaphoreProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedSemaphoreProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -34,4 +36,50 @@ public static IDistributedSynchronizationHandle AcquireSemaphore(this IDistribut /// public static ValueTask AcquireSemaphoreAsync(this IDistributedSemaphoreProvider provider, string name, int maxCount, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateSemaphore(name, maxCount).AcquireAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle? TryAcquireAllSemaphores(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAll( + provider, + static (p, n, mc, t, c) => p.TryAcquireSemaphore(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static IDistributedSynchronizationHandle AcquireAllSemaphores(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAll( + provider, + static (p, n, mc, t, c) => p.AcquireSemaphore(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask TryAcquireAllSemaphoresAsync(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan timeout = default, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.TryAcquireAllAsync( + provider, + static (p, n, mc, t, c) => p.TryAcquireSemaphoreAsync(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + /// + /// Equivalent to calling for each name in and then + /// on each created instance, combining the results into a composite handle. + /// + public static ValueTask AcquireAllSemaphoresAsync(this IDistributedSemaphoreProvider provider, IReadOnlyList names, int maxCount, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => + CompositeDistributedSynchronizationHandle.AcquireAllAsync( + provider, + static (p, n, mc, t, c) => p.AcquireSemaphoreAsync(n, mc, t, c), + names, maxCount, timeout, cancellationToken); + + # endregion } \ No newline at end of file diff --git a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs index 4b8c51f9..2bc1c429 100644 --- a/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs +++ b/src/DistributedLock.Core/DistributedUpgradeableReaderWriterLockProviderExtensions.cs @@ -7,6 +7,8 @@ namespace Medallion.Threading; /// public static class DistributedUpgradeableReaderWriterLockProviderExtensions { + # region Single Lock Methods + /// /// Equivalent to calling and then /// . @@ -34,4 +36,12 @@ public static IDistributedLockUpgradeableHandle AcquireUpgradeableReadLock(this /// public static ValueTask AcquireUpgradeableReadLockAsync(this IDistributedUpgradeableReaderWriterLockProvider provider, string name, TimeSpan? timeout = null, CancellationToken cancellationToken = default) => (provider ?? throw new ArgumentNullException(nameof(provider))).CreateUpgradeableReaderWriterLock(name).AcquireUpgradeableReadLockAsync(timeout, cancellationToken); + + # endregion + + # region Composite Lock Methods + +// Composite methods are not supported for IDistributedUpgradeableReaderWriterLock + + # endregion } \ No newline at end of file diff --git a/src/DistributedLockCodeGen/GenerateProviders.cs b/src/DistributedLockCodeGen/GenerateProviders.cs index ab657b14..80c80747 100644 --- a/src/DistributedLockCodeGen/GenerateProviders.cs +++ b/src/DistributedLockCodeGen/GenerateProviders.cs @@ -10,64 +10,112 @@ namespace DistributedLockCodeGen; [Category("CI")] public class GenerateProviders { - public static readonly IReadOnlyList Interfaces = new[] - { + public static readonly IReadOnlyList Interfaces = + [ "IDistributedLock", "IDistributedReaderWriterLock", "IDistributedUpgradeableReaderWriterLock", "IDistributedSemaphore" - }; + ]; + + private static readonly IReadOnlyList ExcludedInterfacesForCompositeMethods = + [ + "IDistributedUpgradeableReaderWriterLock" + ]; [TestCaseSource(nameof(Interfaces))] public void GenerateProviderInterfaceAndExtensions(string interfaceName) { - var interfaceFile = Directory.GetFiles(CodeGenHelpers.SolutionDirectory, interfaceName + ".cs", SearchOption.AllDirectories) + var interfaceFile = Directory + .GetFiles(CodeGenHelpers.SolutionDirectory, interfaceName + ".cs", SearchOption.AllDirectories) .Single(); var providerInterfaceName = interfaceName + "Provider"; var createMethodName = $"Create{interfaceName.Replace("IDistributed", string.Empty)}"; - var providerInterfaceCode = $@"// AUTO-GENERATED -namespace Medallion.Threading; - -/// -/// Acts as a factory for instances of a certain type. This interface may be -/// easier to use than in dependency injection scenarios. -/// -public interface {providerInterfaceName}{(interfaceName == "IDistributedUpgradeableReaderWriterLock" ? ": IDistributedReaderWriterLockProvider" : string.Empty)} -{{ - /// - /// Constructs an instance with the given . - /// - {interfaceName} {createMethodName}(string name{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}); -}}"; + var providerInterfaceCode = $$""" + // AUTO-GENERATED + namespace Medallion.Threading; + + /// + /// Acts as a factory for instances of a certain type. This interface may be + /// easier to use than in dependency injection scenarios. + /// + public interface {{providerInterfaceName}}{{(interfaceName == "IDistributedUpgradeableReaderWriterLock" ? ": IDistributedReaderWriterLockProvider" : string.Empty)}} + { + /// + /// Constructs an instance with the given . + /// + {{interfaceName}} {{createMethodName}}(string name{{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}}); + } + """; var interfaceMethods = Regex.Matches( File.ReadAllText(interfaceFile), @"(?\S+) (?\S+)\((?((?\S*) (?\w+)[^,)]*(\, )?)*)\);", RegexOptions.ExplicitCapture ); - var extensionMethodBodies = interfaceMethods.Cast() + + var extensionSingleMethodBodies = interfaceMethods .Select(m => -$@" /// - /// Equivalent to calling and then - /// ().Select(c => c.Value))})"" />. - /// - public static {m.Groups["returnType"].Value} {GetExtensionMethodName(m.Groups["name"].Value)}(this {providerInterfaceName} provider, string name{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}, {m.Groups["parameters"].Value}) => - (provider ?? throw new ArgumentNullException(nameof(provider))).{createMethodName}(name{(interfaceName.Contains("Semaphore") ? ", maxCount" : string.Empty)}).{m.Groups["name"].Value}({string.Join(", ", m.Groups["parameterName"].Captures.Cast().Select(c => c.Value))});" + $""" + /// + /// Equivalent to calling and then + /// c.Value))})" />. + /// + public static {m.Groups["returnType"].Value} {GetExtensionMethodName(m.Groups["name"].Value)}(this {providerInterfaceName} provider, string name{(interfaceName.Contains("Semaphore") ? ", int maxCount" : string.Empty)}, {m.Groups["parameters"].Value}) => + (provider ?? throw new ArgumentNullException(nameof(provider))).{createMethodName}(name{(interfaceName.Contains("Semaphore") ? ", maxCount" : string.Empty)}).{m.Groups["name"].Value}({string.Join(", ", m.Groups["parameterName"].Captures.Select(c => c.Value))}); + """ ); + var extensionCompositeMethodBodies = ExcludedInterfacesForCompositeMethods.Contains(interfaceName) + ? + [ + $"// Composite methods are not supported for {interfaceName}" + ] + : interfaceMethods + .Select(m => + { + var (extensionMethodName, innerCallName) = GetAllExtensionMethodName(m.Groups["name"].Value); + var isSemaphore = interfaceName.Contains("Semaphore"); + + return $""" + /// + /// Equivalent to calling for each name in and then + /// c.Value))})" /> on each created instance, combining the results into a composite handle. + /// + public static {m.Groups["returnType"].Value} {extensionMethodName}(this {providerInterfaceName} provider, IReadOnlyList names{(isSemaphore ? ", int maxCount" : string.Empty)}, {m.Groups["parameters"].Value}) => + CompositeDistributedSynchronizationHandle.{innerCallName}( + provider, + static (p, n{(isSemaphore ? ", mc" : string.Empty)}, t, c) => p.{GetExtensionMethodName(m.Groups["name"].Value)}(n{(isSemaphore ? ", mc" : string.Empty)}, t, c), + names,{(isSemaphore ? " maxCount," : string.Empty)} timeout, cancellationToken); + """; + } + ); + var providerExtensionsName = providerInterfaceName.TrimStart('I') + "Extensions"; - var providerExtensionsCode = $@"// AUTO-GENERATED + var providerExtensionsCode = $$""" + // AUTO-GENERATED + + namespace Medallion.Threading; + + /// + /// Productivity helper methods for + /// + public static class {{providerExtensionsName}} + { + # region Single Lock Methods + + {{string.Join(Environment.NewLine + Environment.NewLine, extensionSingleMethodBodies)}} + + # endregion + + # region Composite Lock Methods -namespace Medallion.Threading; + {{string.Join(Environment.NewLine + Environment.NewLine, extensionCompositeMethodBodies)}} -/// -/// Productivity helper methods for -/// -public static class {providerExtensionsName} -{{ -{string.Join(Environment.NewLine + Environment.NewLine, extensionMethodBodies)} -}}"; + # endregion + } + """; var changes = new[] { @@ -76,7 +124,8 @@ public static class {providerExtensionsName} } .Select(t => (file: Path.Combine(Path.GetDirectoryName(interfaceFile)!, t.name + ".cs"), t.code)) .Select(t => (t.file, t.code, originalCode: File.Exists(t.file) ? File.ReadAllText(t.file) : string.Empty)) - .Where(t => CodeGenHelpers.NormalizeCodeWhitespace(t.code) != CodeGenHelpers.NormalizeCodeWhitespace(t.originalCode)) + .Where(t => CodeGenHelpers.NormalizeCodeWhitespace(t.code) != + CodeGenHelpers.NormalizeCodeWhitespace(t.originalCode)) .ToList(); changes.ForEach(t => File.WriteAllText(t.file, t.code)); Assert.That(changes.Select(t => t.file), Is.Empty); @@ -85,8 +134,45 @@ string GetExtensionMethodName(string interfaceMethodName) => Regex.IsMatch(interfaceMethodName, "^(Try)?Acquire(Async)?$") // make it more specific to differentiate when one concrete provider implements multiple provider interfaces ? interfaceMethodName.Replace("Async", string.Empty) - + interfaceName.Replace("IDistributed", string.Empty) - + (interfaceMethodName.EndsWith("Async") ? "Async" : string.Empty) + + interfaceName.Replace("IDistributed", string.Empty) + + (interfaceMethodName.EndsWith("Async") ? "Async" : string.Empty) : interfaceMethodName; + + (string extensionMethodName, string innerCallName) GetAllExtensionMethodName(string interfaceMethodName) + { + var isExactAcquire = Regex.IsMatch(interfaceMethodName, "^(Try)?Acquire(Async)?$"); + var isAsync = interfaceMethodName.EndsWith("Async", StringComparison.Ordinal); + var isTryVariant = interfaceMethodName.StartsWith("Try", StringComparison.Ordinal); + + string extensionMethodName; + + if (!isExactAcquire) + { + // e.g. TryAcquireReadLock -> TryAcquireAllReadLocks + // TryAcquireSemaphore -> TryAcquireAllSemaphores + // TryAcquireUpgradeableReadLockAsync -> TryAcquireUpgradeableAllReadLockAsync + extensionMethodName = interfaceMethodName + .Replace("Acquire", "AcquireAll") // Acquire -> AcquireAll + .Replace("Async", string.Empty) // strip Async (add back later) + + "s" // pluralise + + (isAsync ? "Async" : string.Empty); // restore Async if needed + } + else + { + // e.g. TryAcquire -> TryAcquireAllLocks + // AcquireAsync -> AcquireAllLocksAsync + extensionMethodName = interfaceMethodName.Replace("Async", string.Empty) + + "All" + + interfaceName.Replace("IDistributed", string.Empty) + "s" + + (isAsync ? "Async" : string.Empty); + } + + // - “Try…” methods -> TryAcquireAll[Async] + // - plain Acquire… -> AcquireAll[Async] + var innerCallName = (isTryVariant ? "TryAcquireAll" : "AcquireAll") + + (isAsync ? "Async" : string.Empty); + + return (extensionMethodName, innerCallName); + } } -} +} \ No newline at end of file