Skip to content

Commit

Permalink
Merge branch 'master' into survivedMemory
Browse files Browse the repository at this point in the history
  • Loading branch information
timcassell authored Apr 1, 2024
2 parents 0d6500f + 4ab69be commit d830775
Show file tree
Hide file tree
Showing 11 changed files with 458 additions and 198 deletions.
14 changes: 5 additions & 9 deletions src/BenchmarkDotNet/Code/DeclarationsProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private string GetMethodName(MethodInfo method)
(method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>) ||
method.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))))
{
return $"() => {method.Name}().GetAwaiter().GetResult()";
return $"() => BenchmarkDotNet.Helpers.AwaitHelper.GetResult({method.Name}())";
}

return method.Name;
Expand Down Expand Up @@ -149,12 +149,10 @@ internal class TaskDeclarationsProvider : VoidDeclarationsProvider
{
public TaskDeclarationsProvider(Descriptor descriptor) : base(descriptor) { }

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public override string WorkloadMethodDelegate(string passArguments)
=> $"({passArguments}) => {{ {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
=> $"({passArguments}) => {{ BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";

public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
public override string GetWorkloadMethodCall(string passArguments) => $"BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";

protected override Type WorkloadMethodReturnType => typeof(void);
}
Expand All @@ -168,11 +166,9 @@ public GenericTaskDeclarationsProvider(Descriptor descriptor) : base(descriptor)

protected override Type WorkloadMethodReturnType => Descriptor.WorkloadMethod.ReturnType.GetTypeInfo().GetGenericArguments().Single();

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public override string WorkloadMethodDelegate(string passArguments)
=> $"({passArguments}) => {{ return {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
=> $"({passArguments}) => {{ return BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";

public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
public override string GetWorkloadMethodCall(string passArguments) => $"BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
}
}
34 changes: 24 additions & 10 deletions src/BenchmarkDotNet/Engines/Engine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ public Measurement RunIteration(IterationData data)
if (EngineEventSource.Log.IsEnabled())
EngineEventSource.Log.IterationStart(data.IterationMode, data.IterationStage, totalOperations);

Span<byte> stackMemory = randomizeMemory ? stackalloc byte[random.Next(32)] : Span<byte>.Empty;

bool needsSurvivedMeasurement = includeSurvivedMemory && !isOverhead && !survivedBytesMeasured;
if (needsSurvivedMeasurement && GcStats.InitTotalBytes())
{
Expand All @@ -192,10 +190,9 @@ public Measurement RunIteration(IterationData data)
survivedBytes = afterBytes - beforeBytes;
}

// Measure
var clock = Clock.Start();
action(invokeCount / unrollFactor);
var clockSpan = clock.GetElapsed();
var clockSpan = randomizeMemory
? MeasureWithRandomMemory(action, invokeCount / unrollFactor)
: Measure(action, invokeCount / unrollFactor);

if (EngineEventSource.Log.IsEnabled())
EngineEventSource.Log.IterationStop(data.IterationMode, data.IterationStage, totalOperations);
Expand All @@ -214,9 +211,29 @@ public Measurement RunIteration(IterationData data)
if (measurement.IterationStage == IterationStage.Jitting)
jittingMeasurements.Add(measurement);

return measurement;
}

// This is in a separate method, because stackalloc can affect code alignment,
// resulting in unexpected measurements on some AMD cpus,
// even if the stackalloc branch isn't executed. (#2366)
[MethodImpl(MethodImplOptions.NoInlining)]
private unsafe ClockSpan MeasureWithRandomMemory(Action<long> action, long invokeCount)
{
byte* stackMemory = stackalloc byte[random.Next(32)];
var clockSpan = Measure(action, invokeCount);
Consume(stackMemory);
return clockSpan;
}

return measurement;
[MethodImpl(MethodImplOptions.NoInlining)]
private unsafe void Consume(byte* _) { }

private ClockSpan Measure(Action<long> action, long invokeCount)
{
var clock = Clock.Start();
action(invokeCount);
return clock.GetElapsed();
}

private (GcStats, ThreadingStats, double) GetExtraStats(IterationData data)
Expand Down Expand Up @@ -248,9 +265,6 @@ public Measurement RunIteration(IterationData data)
return (gcStats, threadingStats, exceptionsStats.ExceptionsCount / (double)totalOperationsCount);
}

[MethodImpl(MethodImplOptions.NoInlining)]
private void Consume(in Span<byte> _) { }

private void RandomizeManagedHeapMemory()
{
// invoke global cleanup before global setup
Expand Down
108 changes: 108 additions & 0 deletions src/BenchmarkDotNet/Helpers/AwaitHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace BenchmarkDotNet.Helpers
{
public static class AwaitHelper
{
private class ValueTaskWaiter
{
// We use thread static field so that each thread uses its own individual callback and reset event.
[ThreadStatic]
private static ValueTaskWaiter ts_current;
internal static ValueTaskWaiter Current => ts_current ??= new ValueTaskWaiter();

// We cache the callback to prevent allocations for memory diagnoser.
private readonly Action awaiterCallback;
private readonly ManualResetEventSlim resetEvent;

private ValueTaskWaiter()
{
resetEvent = new ();
awaiterCallback = resetEvent.Set;
}

internal void Wait<TAwaiter>(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion
{
resetEvent.Reset();
awaiter.UnsafeOnCompleted(awaiterCallback);

// The fastest way to wait for completion is to spin a bit before waiting on the event. This is the same logic that Task.GetAwaiter().GetResult() uses.
var spinner = new SpinWait();
while (!resetEvent.IsSet)
{
if (spinner.NextSpinWillYield)
{
resetEvent.Wait();
return;
}
spinner.SpinOnce();
}
}
}

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public static void GetResult(Task task) => task.GetAwaiter().GetResult();

public static T GetResult<T>(Task<T> task) => task.GetAwaiter().GetResult();

// ValueTask can be backed by an IValueTaskSource that only supports asynchronous awaits,
// so we have to hook up a callback instead of calling .GetAwaiter().GetResult() like we do for Task.
// The alternative is to convert it to Task using .AsTask(), but that causes allocations which we must avoid for memory diagnoser.
public static void GetResult(ValueTask task)
{
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
var awaiter = task.ConfigureAwait(false).GetAwaiter();
if (!awaiter.IsCompleted)
{
ValueTaskWaiter.Current.Wait(awaiter);
}
awaiter.GetResult();
}

public static T GetResult<T>(ValueTask<T> task)
{
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
var awaiter = task.ConfigureAwait(false).GetAwaiter();
if (!awaiter.IsCompleted)
{
ValueTaskWaiter.Current.Wait(awaiter);
}
return awaiter.GetResult();
}

internal static MethodInfo GetGetResultMethod(Type taskType)
{
if (!taskType.IsGenericType)
{
return typeof(AwaitHelper).GetMethod(nameof(AwaitHelper.GetResult), BindingFlags.Public | BindingFlags.Static, null, new Type[1] { taskType }, null);
}

Type compareType = taskType.GetGenericTypeDefinition() == typeof(ValueTask<>) ? typeof(ValueTask<>)
: typeof(Task).IsAssignableFrom(taskType.GetGenericTypeDefinition()) ? typeof(Task<>)
: null;
if (compareType == null)
{
return null;
}
var resultType = taskType
.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)
.ReturnType
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)
.ReturnType;
return typeof(AwaitHelper).GetMethods(BindingFlags.Public | BindingFlags.Static)
.First(m =>
{
if (m.Name != nameof(AwaitHelper.GetResult)) return false;
Type paramType = m.GetParameters().First().ParameterType;
return paramType.IsGenericType && paramType.GetGenericTypeDefinition() == compareType;
})
.MakeGenericMethod(new[] { resultType });
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using BenchmarkDotNet.Engines;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
Expand All @@ -16,28 +18,24 @@ public ConsumableTypeInfo(Type methodReturnType)

OriginMethodReturnType = methodReturnType;

// Please note this code does not support await over extension methods.
var getAwaiterMethod = methodReturnType.GetMethod(nameof(Task<int>.GetAwaiter), BindingFlagsPublicInstance);
if (getAwaiterMethod == null)
// Only support (Value)Task for parity with other toolchains (and so we can use AwaitHelper).
IsAwaitable = methodReturnType == typeof(Task) || methodReturnType == typeof(ValueTask)
|| (methodReturnType.GetTypeInfo().IsGenericType
&& (methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(Task<>)
|| methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(ValueTask<>)));

if (!IsAwaitable)
{
WorkloadMethodReturnType = methodReturnType;
}
else
{
var getResultMethod = getAwaiterMethod
WorkloadMethodReturnType = methodReturnType
.GetMethod(nameof(Task.GetAwaiter), BindingFlagsPublicInstance)
.ReturnType
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance);

if (getResultMethod == null)
{
WorkloadMethodReturnType = methodReturnType;
}
else
{
WorkloadMethodReturnType = getResultMethod.ReturnType;
GetAwaiterMethod = getAwaiterMethod;
GetResultMethod = getResultMethod;
}
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance)
.ReturnType;
GetResultMethod = Helpers.AwaitHelper.GetGetResultMethod(methodReturnType);
}

if (WorkloadMethodReturnType == null)
Expand Down Expand Up @@ -74,14 +72,13 @@ public ConsumableTypeInfo(Type methodReturnType)
public Type WorkloadMethodReturnType { get; }
public Type OverheadMethodReturnType { get; }

public MethodInfo? GetAwaiterMethod { get; }
public MethodInfo? GetResultMethod { get; }

public bool IsVoid { get; }
public bool IsByRef { get; }
public bool IsConsumable { get; }
public FieldInfo? WorkloadConsumableField { get; }

public bool IsAwaitable => GetAwaiterMethod != null && GetResultMethod != null;
public bool IsAwaitable { get; }
}
}
Loading

0 comments on commit d830775

Please sign in to comment.