Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed lazy loading thread safety #35529

Merged
merged 6 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 93 additions & 38 deletions src/EFCore/Infrastructure/Internal/LazyLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Internal;

namespace Microsoft.EntityFrameworkCore.Infrastructure.Internal;
Expand All @@ -20,7 +22,8 @@ public class LazyLoader : ILazyLoader, IInjectableService
private bool _disposed;
private bool _detached;
private IDictionary<string, bool>? _loadedStates;
private readonly ConcurrentDictionary<(object Entity, string NavigationName), bool> _isLoading = new(NavEntryEqualityComparer.Instance);
private readonly Lock _isLoadingLock = new Lock();
private readonly Dictionary<(object Entity, string NavigationName), (TaskCompletionSource TaskCompletionSource, AsyncLocal<int> Depth)> _isLoading = new(NavEntryEqualityComparer.Instance);
private HashSet<string>? _nonLazyNavigations;

/// <summary>
Expand Down Expand Up @@ -107,30 +110,56 @@ public virtual void Load(object entity, [CallerMemberName] string navigationName
Check.NotEmpty(navigationName, nameof(navigationName));

var navEntry = (entity, navigationName);
if (_isLoading.TryAdd(navEntry, true))

bool exists;
(TaskCompletionSource TaskCompletionSource, AsyncLocal<int> Depth) isLoadingValue;

lock (_isLoadingLock)
AndriySvyryd marked this conversation as resolved.
Show resolved Hide resolved
{
ref var refIsLoadingValue = ref CollectionsMarshal.GetValueRefOrAddDefault(_isLoading, navEntry, out exists);
if (!exists)
{
refIsLoadingValue = (new(), new());
}
isLoadingValue = refIsLoadingValue!;
isLoadingValue.Depth.Value++;
}

if (exists)
{
// Only waits for the outermost call on the call stack. See #35528.
if (isLoadingValue.Depth.Value == 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting optimization. It definitely warrants a comment on how it works as the next person to read this code will be very confused.

@roji @cincuranet What's your opinion of this usage of AsyncLocal?

{
isLoadingValue.TaskCompletionSource.Task.Wait();
}
return;
}

try
{
try
// ShouldLoad is called after _isLoading.Add because it could attempt to load the property. See #13138.
if (ShouldLoad(entity, navigationName, out var entry))
{
// ShouldLoad is called after _isLoading.Add because it could attempt to load the property. See #13138.
if (ShouldLoad(entity, navigationName, out var entry))
try
{
try
{
entry.Load(
_queryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution
? LoadOptions.ForceIdentityResolution
: LoadOptions.None);
}
catch
{
entry.IsLoaded = false;
throw;
}
entry.Load(
_queryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution
? LoadOptions.ForceIdentityResolution
: LoadOptions.None);
}
catch
{
entry.IsLoaded = false;
throw;
}
}
finally
}
finally
{
isLoadingValue.TaskCompletionSource.TrySetResult();
lock (_isLoadingLock)
{
_isLoading.TryRemove(navEntry, out _);
_isLoading.Remove(navEntry);
}
}
}
Expand All @@ -150,31 +179,57 @@ public virtual async Task LoadAsync(
Check.NotEmpty(navigationName, nameof(navigationName));

var navEntry = (entity, navigationName);
if (_isLoading.TryAdd(navEntry, true))

bool exists;
(TaskCompletionSource TaskCompletionSource, AsyncLocal<int> Depth) isLoadingValue;

lock (_isLoadingLock)
{
ref var refIsLoadingValue = ref CollectionsMarshal.GetValueRefOrAddDefault(_isLoading, navEntry, out exists);
if (!exists)
{
refIsLoadingValue = (new(), new());
}
isLoadingValue = refIsLoadingValue!;
isLoadingValue.Depth.Value++;
}

if (exists)
{
// Only waits for the outermost call on the call stack. See #35528.
if (isLoadingValue.Depth.Value == 1)
{
await isLoadingValue.TaskCompletionSource.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
}
return;
}

try
{
try
// ShouldLoad is called after _isLoading.Add because it could attempt to load the property. See #13138.
if (ShouldLoad(entity, navigationName, out var entry))
{
// ShouldLoad is called after _isLoading.Add because it could attempt to load the property. See #13138.
if (ShouldLoad(entity, navigationName, out var entry))
try
{
try
{
await entry.LoadAsync(
_queryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution
? LoadOptions.ForceIdentityResolution
: LoadOptions.None,
cancellationToken).ConfigureAwait(false);
}
catch
{
entry.IsLoaded = false;
throw;
}
await entry.LoadAsync(
_queryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution
? LoadOptions.ForceIdentityResolution
: LoadOptions.None,
cancellationToken).ConfigureAwait(false);
}
catch
{
entry.IsLoaded = false;
throw;
}
}
finally
}
finally
{
isLoadingValue.TaskCompletionSource.TrySetResult();
lock (_isLoadingLock)
{
_isLoading.TryRemove(navEntry, out _);
_isLoading.Remove(navEntry);
}
}
}
Expand Down
131 changes: 131 additions & 0 deletions test/EFCore.Specification.Tests/LoadTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5035,6 +5035,60 @@ public virtual void Setting_navigation_to_null_is_detected_by_local_DetectChange
Assert.Equal(EntityState.Deleted, childEntry.State);
}

[ConditionalTheory] // Issue #35528
[InlineData(false, false)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(true, true)]
public virtual async Task Lazy_loading_is_thread_safe(bool noTracking, bool async)
{
using var context = CreateContext(lazyLoadingEnabled: true);

//Creating another context to avoid caches
using var context2 = CreateContext(lazyLoadingEnabled: true);

IQueryable<Parent> query = context.Set<Parent>();
IQueryable<Parent> query2 = context2.Set<Parent>();

if (noTracking)
{
query = query.AsNoTracking();
query2 = query2.AsNoTracking();
}

var parent = query.Single();

var children = (await parent.LazyLoadChildren(async))?.Select(x => x.Id).OrderBy(x => x).ToList();
var singlePkToPk = (await parent.LazyLoadSinglePkToPk(async))?.Id;
var single = (await parent.LazyLoadSingle(async))?.Id;
var childrenAk = (await parent.LazyLoadChildrenAk(async))?.Select(x => x.Id).OrderBy(x => x).ToList();
var singleAk = (await parent.LazyLoadSingleAk(async))?.Id;
var childrenShadowFk = (await parent.LazyLoadChildrenShadowFk(async))?.Select(x => x.Id).OrderBy(x => x).ToList();
var singleShadowFk = (await parent.LazyLoadSingleShadowFk(async))?.Id;
var childrenCompositeKey = (await parent.LazyLoadChildrenCompositeKey(async))?.Select(x => x.Id).OrderBy(x => x).ToList();
var singleCompositeKey = (await parent.LazyLoadSingleCompositeKey(async))?.Id;

var parent2 = query2.Single();

var parallelOptions = new ParallelOptions
{
MaxDegreeOfParallelism = Environment.ProcessorCount * 500
};

await Parallel.ForAsync(0, 50000, parallelOptions, async (i, ct) =>
{
Assert.Equal(children, (await parent2.LazyLoadChildren(async))?.Select(x => x.Id).OrderBy(x => x).ToList());
Assert.Equal(singlePkToPk, (await parent2.LazyLoadSinglePkToPk(async))?.Id);
Assert.Equal(single, (await parent2.LazyLoadSingle(async))?.Id);
Assert.Equal(childrenAk, (await parent2.LazyLoadChildrenAk(async))?.Select(x => x.Id).OrderBy(x => x).ToList());
Assert.Equal(singleAk, (await parent2.LazyLoadSingleAk(async))?.Id);
Assert.Equal(childrenShadowFk, (await parent2.LazyLoadChildrenShadowFk(async))?.Select(x => x.Id).OrderBy(x => x).ToList());
Assert.Equal(singleShadowFk, (await parent2.LazyLoadSingleShadowFk(async))?.Id);
Assert.Equal(childrenCompositeKey, (await parent2.LazyLoadChildrenCompositeKey(async))?.Select(x => x.Id).OrderBy(x => x).ToList());
Assert.Equal(singleCompositeKey, (await parent2.LazyLoadSingleCompositeKey(async))?.Id);
});
}

private static void SetState(
DbContext context,
object entity,
Expand Down Expand Up @@ -5092,6 +5146,17 @@ public SinglePkToPk SinglePkToPk
set => _singlePkToPk = value;
}

public async Task<SinglePkToPk> LazyLoadSinglePkToPk(bool async)
{
if (async)
{
await Loader.LoadAsync(this, default, nameof(SinglePkToPk));
return _singlePkToPk;
}

return SinglePkToPk;
}

public Single Single
{
get => Loader.Load(this, ref _single);
Expand Down Expand Up @@ -5121,35 +5186,101 @@ public IEnumerable<ChildAk> ChildrenAk
set => _childrenAk = value;
}

public async Task<IEnumerable<ChildAk>> LazyLoadChildrenAk(bool async)
{
if (async)
{
await Loader.LoadAsync(this, default, nameof(ChildrenAk));
return _childrenAk;
}

return ChildrenAk;
}

public SingleAk SingleAk
{
get => Loader.Load(this, ref _singleAk);
set => _singleAk = value;
}

public async Task<SingleAk> LazyLoadSingleAk(bool async)
{
if (async)
{
await Loader.LoadAsync(this, default, nameof(SingleAk));
return _singleAk;
}

return SingleAk;
}

public IEnumerable<ChildShadowFk> ChildrenShadowFk
{
get => Loader.Load(this, ref _childrenShadowFk);
set => _childrenShadowFk = value;
}

public async Task<IEnumerable<ChildShadowFk>> LazyLoadChildrenShadowFk(bool async)
{
if (async)
{
await Loader.LoadAsync(this, default, nameof(ChildrenShadowFk));
return _childrenShadowFk;
}

return ChildrenShadowFk;
}

public SingleShadowFk SingleShadowFk
{
get => Loader.Load(this, ref _singleShadowFk);
set => _singleShadowFk = value;
}

public async Task<SingleShadowFk> LazyLoadSingleShadowFk(bool async)
{
if (async)
{
await Loader.LoadAsync(this, default, nameof(SingleShadowFk));
return _singleShadowFk;
}

return SingleShadowFk;
}

public IEnumerable<ChildCompositeKey> ChildrenCompositeKey
{
get => Loader.Load(this, ref _childrenCompositeKey);
set => _childrenCompositeKey = value;
}

public async Task<IEnumerable<ChildCompositeKey>> LazyLoadChildrenCompositeKey(bool async)
{
if (async)
{
await Loader.LoadAsync(this, default, nameof(ChildrenCompositeKey));
return _childrenCompositeKey;
}

return ChildrenCompositeKey;
}

public SingleCompositeKey SingleCompositeKey
{
get => Loader.Load(this, ref _singleCompositeKey);
set => _singleCompositeKey = value;
}

public async Task<SingleCompositeKey> LazyLoadSingleCompositeKey(bool async)
{
if (async)
{
await Loader.LoadAsync(this, default, nameof(SingleCompositeKey));
return _singleCompositeKey;
}

return SingleCompositeKey;
}
}

protected class Child
Expand Down
Loading