diff --git a/src/MongoDB.Driver/Core/IAsyncCursor.cs b/src/MongoDB.Driver/Core/IAsyncCursor.cs index 55a655f5b2a..37610586fc6 100644 --- a/src/MongoDB.Driver/Core/IAsyncCursor.cs +++ b/src/MongoDB.Driver/Core/IAsyncCursor.cs @@ -360,6 +360,17 @@ public static class IAsyncCursorExtensions return new AsyncCursorEnumerableOneTimeAdapter(cursor, cancellationToken); } + /// + /// Wraps a cursor in an IAsyncEnumerable that can be enumerated one time. + /// + /// The type of the document. + /// The cursor. + /// An IAsyncEnumerable. + public static IAsyncEnumerable ToAsyncEnumerable(this IAsyncCursor cursor) + { + return new AsyncCursorEnumerableOneTimeAdapter(cursor); + } + /// /// Returns a list containing all the documents returned by a cursor. /// diff --git a/src/MongoDB.Driver/Core/IAsyncCursorSource.cs b/src/MongoDB.Driver/Core/IAsyncCursorSource.cs index e6710c5ad49..669a99ffd47 100644 --- a/src/MongoDB.Driver/Core/IAsyncCursorSource.cs +++ b/src/MongoDB.Driver/Core/IAsyncCursorSource.cs @@ -336,6 +336,17 @@ public static class IAsyncCursorSourceExtensions return new AsyncCursorSourceEnumerableAdapter(source, cancellationToken); } + /// + /// Wraps a cursor source in an IAsyncEnumerable. Each time GetAsyncEnumerator is called a new cursor is fetched from the cursor source. + /// + /// The type of the document. + /// The source. + /// An IAsyncEnumerable. + public static IAsyncEnumerable ToAsyncEnumerable(this IAsyncCursorSource source) + { + return new AsyncCursorSourceEnumerableAdapter(source); + } + /// /// Returns a list containing all the documents returned by the cursor returned by a cursor source. /// diff --git a/src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerableOneTimeAdapter.cs b/src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerableOneTimeAdapter.cs index 3303143eaf9..55b47c64a63 100644 --- a/src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerableOneTimeAdapter.cs +++ b/src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerableOneTimeAdapter.cs @@ -21,18 +21,33 @@ namespace MongoDB.Driver.Core.Operations { - internal sealed class AsyncCursorEnumerableOneTimeAdapter : IEnumerable + internal sealed class AsyncCursorEnumerableOneTimeAdapter : IEnumerable, IAsyncEnumerable { private readonly CancellationToken _cancellationToken; private readonly IAsyncCursor _cursor; private bool _hasBeenEnumerated; + public AsyncCursorEnumerableOneTimeAdapter(IAsyncCursor cursor) + : this(cursor, CancellationToken.None) + { + } + public AsyncCursorEnumerableOneTimeAdapter(IAsyncCursor cursor, CancellationToken cancellationToken) { _cursor = Ensure.IsNotNull(cursor, nameof(cursor)); _cancellationToken = cancellationToken; } + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + if (_hasBeenEnumerated) + { + throw new InvalidOperationException("An IAsyncCursor can only be enumerated once."); + } + _hasBeenEnumerated = true; + return new AsyncCursorEnumerator(_cursor, cancellationToken); + } + public IEnumerator GetEnumerator() { if (_hasBeenEnumerated) diff --git a/src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerator.cs b/src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerator.cs index 6778b38146a..2015a9c3136 100644 --- a/src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerator.cs +++ b/src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerator.cs @@ -17,11 +17,12 @@ using System.Collections; using System.Collections.Generic; using System.Threading; +using System.Threading.Tasks; using MongoDB.Driver.Core.Misc; namespace MongoDB.Driver.Core.Operations { - internal class AsyncCursorEnumerator : IEnumerator + internal class AsyncCursorEnumerator : IEnumerator, IAsyncEnumerator { // private fields private IEnumerator _batchEnumerator; @@ -72,6 +73,12 @@ public void Dispose() } } + public ValueTask DisposeAsync() + { + Dispose(); + return default; + } + public bool MoveNext() { ThrowIfDisposed(); @@ -82,24 +89,46 @@ public bool MoveNext() return true; } - while (true) + while (_cursor.MoveNext(_cancellationToken)) { - if (_cursor.MoveNext(_cancellationToken)) + _batchEnumerator?.Dispose(); + _batchEnumerator = _cursor.Current.GetEnumerator(); + if (_batchEnumerator.MoveNext()) { - _batchEnumerator?.Dispose(); - _batchEnumerator = _cursor.Current.GetEnumerator(); - if (_batchEnumerator.MoveNext()) - { - return true; - } + return true; } - else + } + + _batchEnumerator?.Dispose(); + _batchEnumerator = null; + _finished = true; + return false; + } + + public async ValueTask MoveNextAsync() + { + ThrowIfDisposed(); + _started = true; + + if (_batchEnumerator != null && _batchEnumerator.MoveNext()) + { + return true; + } + + while (await _cursor.MoveNextAsync(_cancellationToken).ConfigureAwait(false)) + { + _batchEnumerator?.Dispose(); + _batchEnumerator = _cursor.Current.GetEnumerator(); + if (_batchEnumerator.MoveNext()) { - _batchEnumerator = null; - _finished = true; - return false; + return true; } } + + _batchEnumerator?.Dispose(); + _batchEnumerator = null; + _finished = true; + return false; } public void Reset() diff --git a/src/MongoDB.Driver/Core/Operations/AsyncCursorSourceEnumerableAdapter.cs b/src/MongoDB.Driver/Core/Operations/AsyncCursorSourceEnumerableAdapter.cs index 9d175b1fde8..4d29caa2623 100644 --- a/src/MongoDB.Driver/Core/Operations/AsyncCursorSourceEnumerableAdapter.cs +++ b/src/MongoDB.Driver/Core/Operations/AsyncCursorSourceEnumerableAdapter.cs @@ -13,7 +13,6 @@ * limitations under the License. */ -using System; using System.Collections; using System.Collections.Generic; using System.Threading; @@ -21,19 +20,30 @@ namespace MongoDB.Driver.Core.Operations { - internal class AsyncCursorSourceEnumerableAdapter : IEnumerable + internal class AsyncCursorSourceEnumerableAdapter : IEnumerable, IAsyncEnumerable { // private fields private readonly CancellationToken _cancellationToken; private readonly IAsyncCursorSource _source; // constructors + public AsyncCursorSourceEnumerableAdapter(IAsyncCursorSource source) + : this(source, CancellationToken.None) + { + } + public AsyncCursorSourceEnumerableAdapter(IAsyncCursorSource source, CancellationToken cancellationToken) { _source = Ensure.IsNotNull(source, nameof(source)); _cancellationToken = cancellationToken; } + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + var cursor = _source.ToCursor(cancellationToken); + return new AsyncCursorEnumerator(cursor, cancellationToken); + } + // public methods public IEnumerator GetEnumerator() { diff --git a/src/MongoDB.Driver/Linq/MongoQueryable.cs b/src/MongoDB.Driver/Linq/MongoQueryable.cs index 70c385d4284..48cc283c65d 100644 --- a/src/MongoDB.Driver/Linq/MongoQueryable.cs +++ b/src/MongoDB.Driver/Linq/MongoQueryable.cs @@ -3385,6 +3385,18 @@ public static IQueryable Take(this IQueryable source, Expression.Constant(count))); } + /// + /// Returns an which can be enumerated asynchronously. + /// + /// The type of the elements of . + /// A sequence of values. + /// An IAsyncEnumerable for the query results. + public static IAsyncEnumerable ToAsyncEnumerable(this IQueryable source) + { + var cursorSource = GetCursorSource(source); + return cursorSource.ToAsyncEnumerable(); + } + /// /// Executes the LINQ query and returns a cursor to the results. /// diff --git a/tests/MongoDB.Driver.Tests/Core/IAsyncCursorExtensionsTests.cs b/tests/MongoDB.Driver.Tests/Core/IAsyncCursorExtensionsTests.cs index aff630021b5..e96e4c1a544 100644 --- a/tests/MongoDB.Driver.Tests/Core/IAsyncCursorExtensionsTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/IAsyncCursorExtensionsTests.cs @@ -16,6 +16,8 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; +using System.Threading.Tasks; using FluentAssertions; using MongoDB.Bson; using MongoDB.Bson.Serialization.Serializers; @@ -201,6 +203,55 @@ public void SingleOrDefault_should_throw_when_cursor_has_wrong_number_of_documen action.ShouldThrow(); } + [Fact] + public void ToAsyncEnumerable_result_should_only_be_enumerable_one_time() + { + var cursor = CreateCursor(2); + var enumerable = cursor.ToAsyncEnumerable(); + enumerable.GetAsyncEnumerator(); + + Action action = () => enumerable.GetAsyncEnumerator(); + + action.ShouldThrow(); + } + + [Fact] + public async Task ToAsyncEnumerable_should_respect_cancellation_token() + { + var source = CreateCursor(5); + using var cts = new CancellationTokenSource(); + + var count = 0; + await Assert.ThrowsAsync(async () => + { + await foreach (var doc in source.ToAsyncEnumerable().WithCancellation(cts.Token)) + { + count++; + if (count == 2) + cts.Cancel(); + } + }); + } + + [Fact] + public async Task ToAsyncEnumerable_should_return_expected_result() + { + var cursor = CreateCursor(2); + var expectedDocuments = new[] + { + new BsonDocument("_id", 0), + new BsonDocument("_id", 1) + }; + + var result = new List(); + await foreach (var doc in cursor.ToAsyncEnumerable()) + { + result.Add(doc); + } + + result.Should().Equal(expectedDocuments); + } + [Fact] public void ToEnumerable_result_should_only_be_enumerable_one_time() { diff --git a/tests/MongoDB.Driver.Tests/Core/IAsyncCursorSourceExtensionsTests.cs b/tests/MongoDB.Driver.Tests/Core/IAsyncCursorSourceExtensionsTests.cs index e2b8188ce43..dedd7cd59ae 100644 --- a/tests/MongoDB.Driver.Tests/Core/IAsyncCursorSourceExtensionsTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/IAsyncCursorSourceExtensionsTests.cs @@ -203,6 +203,31 @@ public void SingleOrDefault_should_throw_when_cursor_has_wrong_number_of_documen action.ShouldThrow(); } + [Theory] + [ParameterAttributeData] + public async Task ToAsyncEnumerable_result_should_be_enumerable_multiple_times( + [Values(1, 2)] int times) + { + var source = CreateCursorSource(2); + var expectedDocuments = new[] + { + new BsonDocument("_id", 0), + new BsonDocument("_id", 1) + }; + + var result = new List(); + for (var i = 0; i < times; i++) + { + await foreach (var doc in source.ToAsyncEnumerable()) + { + result.Add(doc); + } + + result.Should().Equal(expectedDocuments); + result.Clear(); + } + } + [Theory] [ParameterAttributeData] public void ToEnumerable_result_should_be_enumerable_multiple_times( diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs index d68f43e1fe7..d5f527cc47b 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs @@ -19,7 +19,6 @@ using System.Threading.Tasks; using FluentAssertions; using MongoDB.Bson; -using MongoDB.Driver; using MongoDB.Driver.Core.Clusters; using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Core.TestHelpers.XunitExtensions; @@ -78,6 +77,21 @@ public async Task AnyAsync_with_predicate() result.Should().BeTrue(); } + [Fact] + public async Task ToAsyncEnumerable() + { + var query = CreateQuery().Select(x => x.A); + var expectedResults = query.ToList(); + + var asyncResults = new List(); + await foreach (var item in query.ToAsyncEnumerable().ConfigureAwait(false)) + { + asyncResults.Add(item); + } + + asyncResults.Should().Equal(expectedResults); + } + [Fact] public void Average() {