Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions src/Baballonia/Models/SliderBindableSetting.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
using CommunityToolkit.Mvvm.ComponentModel;
using Baballonia.Contracts;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel;
using Avalonia.Threading;

namespace Baballonia.Models;

Expand All @@ -18,7 +23,77 @@ public SliderBindableSetting(string name, float lower = 0f, float upper = 1f, fl
Name = name;
Lower = lower;
Upper = upper;
Min = max;
Max = min;
Min = min;
Max = max;
}
}

public partial class ParameterGroupCollection(
string groupName,
IFilterSettings filterSettings,
IEnumerable<SliderBindableSetting> items)
: ObservableCollection<SliderBindableSetting>(items)
{
public string GroupName { get; } = groupName;
public IFilterSettings FilterSettings { get; } = filterSettings;
}

public interface IFilterSettings : INotifyPropertyChanged
{
bool Enabled { get; set; }
float MinFreqCutoff { get; set; }
float SpeedCutoff { get; set; }
}

public partial class GroupFilterSettings : ObservableObject, IFilterSettings
{
private readonly ILocalSettingsService _localSettingsService;
private readonly string _prefix;

[ObservableProperty]
private bool _enabled;

[ObservableProperty]
private float _minFreqCutoff;

[ObservableProperty]
private float _speedCutoff;

public GroupFilterSettings(ILocalSettingsService localSettingsService, string settingPrefix,
bool defaultEnabled, float defaultMinFreqCutoff, float defaultSpeedCutoff)
{
_localSettingsService = localSettingsService;
_prefix = settingPrefix;

Enabled = defaultEnabled;
MinFreqCutoff = defaultMinFreqCutoff;
SpeedCutoff = defaultSpeedCutoff;

var enabled = _localSettingsService.ReadSetting($"{_prefix}_Enabled", defaultEnabled);
var min = _localSettingsService.ReadSetting($"{_prefix}_MinFreq", defaultMinFreqCutoff);
var speed = _localSettingsService.ReadSetting($"{_prefix}_Speed", defaultSpeedCutoff);

Dispatcher.UIThread.Post(() =>
{
Enabled = enabled;
MinFreqCutoff = min;
SpeedCutoff = speed;
});

PropertyChanged += (_, e) =>
{
switch (e.PropertyName)
{
case nameof(Enabled):
_localSettingsService.SaveSetting($"{_prefix}_Enabled", Enabled);
break;
case nameof(MinFreqCutoff):
_localSettingsService.SaveSetting($"{_prefix}_MinFreq", MinFreqCutoff);
break;
case nameof(SpeedCutoff):
_localSettingsService.SaveSetting($"{_prefix}_Speed", SpeedCutoff);
break;
}
};
}
}
8 changes: 1 addition & 7 deletions src/Baballonia/Services/EyePipelineManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,7 @@ public void LoadFilter()
if (!enabled)
return;

var eyeArray = new float[Utils.EyeRawExpressions];
var eyeFilter = new OneEuroFilter(
eyeArray,
minCutoff: cutoff,
beta: speedCutoff
);

var eyeFilter = new GroupedOneEuroFilter();
_pipeline.Filter = eyeFilter;
}

Expand Down
8 changes: 1 addition & 7 deletions src/Baballonia/Services/Inference/FacePipelineManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,7 @@ public void LoadFilter()
if (!enabled)
return;

var faceArray = new float[Utils.FaceRawExpressions];
var faceFilter = new OneEuroFilter(
faceArray,
minCutoff: cutoff,
beta: speedCutoff
);

var faceFilter = new GroupedOneEuroFilter();
_pipeline.Filter = faceFilter;
}

Expand Down
159 changes: 84 additions & 75 deletions src/Baballonia/Services/Inference/Filters/OneEuroFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,100 +3,109 @@

namespace Baballonia.Services.Inference.Filters;

public class OneEuroFilter : IFilter
public class GroupedOneEuroFilter : IFilter
{
private float[] minCutoff;
private float[] beta;
private float[] dCutoff;
private float[] xPrev;
private float[] dxPrev;
private DateTime tPrev;
public OneEuroFilter(float[] x0, float minCutoff = 1.0f, float beta = 0.0f)
private sealed class GroupState
{
float dx0 = 0.0f;
float dCutoff = 1.0f;
int length = x0.Length;
this.minCutoff = CreateFilledArray(length, minCutoff);
this.beta = CreateFilledArray(length, beta);
this.dCutoff = CreateFilledArray(length, dCutoff);
// Previous values.
this.xPrev = (float[])x0.Clone();
this.dxPrev = CreateFilledArray(length, dx0);
this.tPrev = DateTime.UtcNow;
public int[] Indices = Array.Empty<int>();
public float[] XPrev = Array.Empty<float>();
public float[] DxPrev = Array.Empty<float>();
public float MinCutoff;
public float Beta;
public float DCutoff = 1.0f;
public DateTime TPrev;
public bool Initialized;
}

public float[] Filter(float[] x)
{
if (x.Length != xPrev.Length)
throw new ArgumentException($"Input shape does not match initial shape. Expected: {xPrev.Length}, got: {x.Length}");

DateTime now = DateTime.UtcNow;
float elapsedTime = (float)(now - tPrev).TotalSeconds;

if (elapsedTime == 0.0f)
{
xPrev = (float[])x.Clone();
return x;
}

float[] t_e = CreateFilledArray(x.Length, elapsedTime);
private readonly Dictionary<string, GroupState> _groups = new();

// Derivative
float[] dx = new float[x.Length];
for (int i = 0; i < x.Length; i++)
{
dx[i] = (x[i] - xPrev[i]) / t_e[i];
}

float[] a_d = SmoothingFactor(t_e, dCutoff);
float[] dxHat = ExponentialSmoothing(a_d, dx, dxPrev);
public void ConfigureGroup(string groupName, int[] parameterIndices, float minCutoff, float beta)
{
if (parameterIndices.Length == 0)
return;

// Adjusted cutoff
float[] cutoff = new float[x.Length];
for (int i = 0; i < x.Length; i++)
var state = new GroupState
{
cutoff[i] = minCutoff[i] + beta[i] * Math.Abs(dxHat[i]);
}

float[] a = SmoothingFactor(t_e, cutoff);
float[] xHat = ExponentialSmoothing(a, x, xPrev);

// Store previous values
xPrev = xHat;
dxPrev = dxHat;
tPrev = now;

return xHat;
Indices = (int[])parameterIndices.Clone(),
XPrev = new float[parameterIndices.Length],
DxPrev = new float[parameterIndices.Length],
MinCutoff = Math.Max(0.001f, minCutoff),
Beta = Math.Max(0f, beta),
TPrev = DateTime.UtcNow,
Initialized = false
};

_groups[groupName] = state;
}

private float[] CreateFilledArray(int length, float value)
public void DisableGroup(string groupName)
{
float[] arr = new float[length];
for (int i = 0; i < length; i++) arr[i] = value;
return arr;
_groups.Remove(groupName);
}

private float[] SmoothingFactor(float[] t_e, float[] cutoff)
public float[] Filter(float[] input)
{
int length = t_e.Length;
float[] result = new float[length];
for (int i = 0; i < length; i++)
if (_groups.Count == 0)
return input;

var now = DateTime.UtcNow;
float[] result = (float[])input.Clone();

foreach (var kvp in _groups)
{
float r = 2 * (float)Math.PI * cutoff[i] * t_e[i];
result[i] = r / (r + 1);
var state = kvp.Value;
if (state.Indices.Length == 0)
continue;

int n = state.Indices.Length;
float[] x = new float[n];
var indices = state.Indices;
for (int i = 0; i < n; i++)
{
x[i] = input[indices[i]];
}

float dt = (float)(now - state.TPrev).TotalSeconds;
if (!state.Initialized || dt <= 0f)
{
for (int i = 0; i < n; i++)
state.XPrev[i] = x[i];
state.TPrev = now;
state.Initialized = true;
continue;
}

// dx = (x - xPrev) / dt
for (int i = 0; i < n; i++)
{
state.DxPrev[i] = OneEuroSmooth(state.DCutoff, dt, (x[i] - state.XPrev[i]) / dt, state.DxPrev[i]);
}

// cutoff = minCutoff + beta * |dxHat|
for (int i = 0; i < n; i++)
{
float cutoff = state.MinCutoff + state.Beta * MathF.Abs(state.DxPrev[i]);
float a = SmoothingFactor(cutoff, dt);
float xHat = a * x[i] + (1f - a) * state.XPrev[i];
state.XPrev[i] = xHat;
result[indices[i]] = xHat;
}

state.TPrev = now;
}

return result;
}

private float[] ExponentialSmoothing(float[] a, float[] x, float[] xPrev)
private static float OneEuroSmooth(float cutoff, float dt, float value, float prev)
{
int length = a.Length;
float[] result = new float[length];
for (int i = 0; i < length; i++)
{
result[i] = a[i] * x[i] + (1 - a[i]) * xPrev[i];
}
return result;
float a = SmoothingFactor(cutoff, dt);
return a * value + (1f - a) * prev;
}

private static float SmoothingFactor(float cutoff, float dt)
{
float r = 2f * MathF.PI * cutoff * dt;
return r / (r + 1f);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class PlatformSettings(
Size inputSize,
InferenceSession session,
DenseTensor<float> tensor,
OneEuroFilter oneEuroFilter,
IFilter oneEuroFilter,
float lastTime,
string inputName,
string modelName)
Expand All @@ -19,7 +19,7 @@ public class PlatformSettings(
public InferenceSession Session { get; } = session;
public DenseTensor<float> Tensor { get; } = tensor;

public OneEuroFilter Filter { get; } = oneEuroFilter;
public IFilter Filter { get; } = oneEuroFilter;
public float LastTime { get; set; } = lastTime;
public string InputName { get; } = inputName;
public string ModelName { get; } = modelName;
Expand Down
12 changes: 0 additions & 12 deletions src/Baballonia/ViewModels/SplitViewPane/AppSettingsViewModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,6 @@ public partial class AppSettingsViewModel : ViewModelBase
[property: SavedSetting("AppSettings_OSCPrefix", "")]
private string _oscPrefix;

[ObservableProperty]
[property: SavedSetting("AppSettings_OneEuroEnabled", true)]
private bool _oneEuroMinEnabled;

[ObservableProperty]
[property: SavedSetting("AppSettings_OneEuroMinFreqCutoff", 1f)]
private float _oneEuroMinFreqCutoff;

[ObservableProperty]
[property: SavedSetting("AppSettings_OneEuroSpeedCutoff", 1f)]
private float _oneEuroSpeedCutoff;

[ObservableProperty]
[property: SavedSetting("AppSettings_UseGPU", true)]
private bool _useGPU;
Expand Down
Loading