Skip to content

Commit e09c045

Browse files
committed
Rework the Multinomial Naive Bayes classifier to compute probabilities in log space with Laplace smoothing, and add class balancing (under-/oversampling) to training
1 parent 6262de8 commit e09c045

File tree

2 files changed

+185
-52
lines changed

2 files changed

+185
-52
lines changed

src/Abstracts/NaiveBayesClassifier.cs

+81-7
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,25 @@ public abstract class NaiveBayesClassifier : INaiveBayesClassifier
66
{
77
protected Dictionary<string, Dictionary<string, int>> wordCountsPerLabel;
88
protected Dictionary<string, int> totalWordsPerLabel;
9+
protected readonly int maxSamplesPerClass;
10+
protected readonly bool useUndersampling;
11+
protected readonly Random random;
912

10-
public NaiveBayesClassifier()
13+
/// <summary>
/// Initializes the training dictionaries and the class-balancing settings.
/// </summary>
/// <param name="maxSamplesPerClass">Upper bound on samples kept per class when balancing; must be positive.</param>
/// <param name="useUndersampling">True to undersample majority classes down to the minority size; false to oversample minority classes up to the majority size.</param>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maxSamplesPerClass"/> is zero or negative.</exception>
public NaiveBayesClassifier(int maxSamplesPerClass = 1000, bool useUndersampling = true)
{
    // A non-positive cap would make BalanceDataset silently discard every
    // training sample (target size <= 0), so fail fast instead.
    if (maxSamplesPerClass <= 0)
    {
        throw new ArgumentOutOfRangeException(nameof(maxSamplesPerClass), maxSamplesPerClass, "Value must be positive.");
    }

    wordCountsPerLabel = new Dictionary<string, Dictionary<string, int>>();
    totalWordsPerLabel = new Dictionary<string, int>();
    this.maxSamplesPerClass = maxSamplesPerClass;
    this.useUndersampling = useUndersampling;
    this.random = new Random();
}
1521

1622
public virtual void Train(IEnumerable<ClassifierModel> trainingData)
1723
{
18-
foreach (var data in trainingData)
24+
var balancedData = BalanceDataset(trainingData.ToList());
25+
26+
foreach (var data in balancedData)
1927
{
20-
// Initialize dictionary for new labels
2128
if (!wordCountsPerLabel.ContainsKey(data.Label))
2229
{
2330
wordCountsPerLabel[data.Label] = new Dictionary<string, int>();
@@ -33,16 +40,83 @@ public virtual void Train(IEnumerable<ClassifierModel> trainingData)
3340
targetDictionary[word] = 0;
3441
targetDictionary[word]++;
3542
}
36-
3743
totalWordsPerLabel[data.Label]++;
3844
}
3945
}
4046

41-
protected IEnumerable<string> Tokenize(string text, char separator = ' ')
47+
/// <summary>
/// Rebalances the training set so every class contributes a comparable number of
/// samples: majority classes are undersampled, or (when oversampling is enabled)
/// minority classes are oversampled, with the per-class size capped by
/// maxSamplesPerClass.
/// </summary>
/// <param name="data">Raw training samples, grouped here by their <c>Label</c>.</param>
/// <returns>A rebalanced sequence of samples; empty when <paramref name="data"/> is empty.</returns>
protected virtual IEnumerable<ClassifierModel> BalanceDataset(List<ClassifierModel> data)
{
    // Min()/Max() below throw InvalidOperationException on an empty sequence,
    // so an empty training set must be handled up front.
    if (data.Count == 0)
    {
        return data;
    }

    // Group samples by label.
    var groupedData = data.GroupBy(x => x.Label)
                          .ToDictionary(g => g.Key, g => g.ToList());

    // Minority and majority class sizes drive the balancing target.
    int minClassSize = groupedData.Values.Min(x => x.Count);
    int maxClassSize = groupedData.Values.Max(x => x.Count);

    // Undersampling shrinks everything toward the smallest class; oversampling
    // grows everything toward the largest. Both are capped by maxSamplesPerClass.
    int targetSize = useUndersampling
        ? Math.Min(minClassSize, maxSamplesPerClass)
        : Math.Min(maxClassSize, maxSamplesPerClass);

    var balancedData = new List<ClassifierModel>();

    foreach (var group in groupedData)
    {
        var samples = group.Value;

        if (samples.Count > targetSize)
        {
            // Too many samples for this class: keep a random subset.
            balancedData.AddRange(UndersampleData(samples, targetSize));
        }
        else if (!useUndersampling && samples.Count < targetSize)
        {
            // Too few samples and oversampling is enabled: duplicate random samples.
            balancedData.AddRange(OversampleData(samples, targetSize));
        }
        else
        {
            // Already at (or acceptably below) the target size.
            balancedData.AddRange(samples);
        }
    }

    return balancedData;
}
90+
91+
/// <summary>
/// Selects a uniformly random subset of <paramref name="targetSize"/> samples by
/// ordering the list with random sort keys and taking the resulting prefix.
/// </summary>
/// <param name="samples">Samples of a single class to shrink.</param>
/// <param name="targetSize">Number of samples to keep.</param>
protected virtual IEnumerable<ClassifierModel> UndersampleData(List<ClassifierModel> samples, int targetSize)
{
    // OrderBy evaluates the key selector once per element, so this is a shuffle.
    var shuffled = samples.OrderBy(_ => random.Next());
    return shuffled.Take(targetSize);
}
95+
96+
/// <summary>
/// Grows the sample list up to <paramref name="targetSize"/> elements by appending
/// randomly chosen duplicates of the existing samples.
/// </summary>
/// <param name="samples">Samples of a single class to grow.</param>
/// <param name="targetSize">Desired number of samples after oversampling.</param>
protected virtual IEnumerable<ClassifierModel> OversampleData(List<ClassifierModel> samples, int targetSize)
{
    var padded = new List<ClassifierModel>(samples);

    // Keep duplicating random originals until the target size is reached.
    for (int i = padded.Count; i < targetSize; i++)
    {
        padded.Add(samples[random.Next(samples.Count)]);
    }

    return padded;
}
108+
109+
/// <summary>
/// Splits <paramref name="text"/> into lower-cased, non-empty tokens.
/// </summary>
/// <param name="text">Text to tokenize.</param>
/// <param name="separator">Character used to split the text; defaults to a space.</param>
/// <returns>Lower-cased tokens with empty entries removed.</returns>
protected virtual IEnumerable<string> Tokenize(string text, char separator = ' ')
{
    // ToLowerInvariant avoids culture-sensitive casing surprises (e.g. the Turkish
    // dotless i), and RemoveEmptyEntries stops runs of separators from producing
    // empty-string "words" that would otherwise inflate the word counts.
    return text.ToLowerInvariant()
               .Split(new[] { separator }, StringSplitOptions.RemoveEmptyEntries);
}
45-
46-
public abstract string Predict(string text);
113+
114+
/// <summary>
/// Counts how many samples carry each label.
/// </summary>
/// <param name="data">Samples to inspect.</param>
/// <returns>A map from each label to the number of samples bearing it.</returns>
public Dictionary<string, int> GetClassDistribution(IEnumerable<ClassifierModel> data)
{
    var distribution = new Dictionary<string, int>();

    foreach (var sample in data)
    {
        distribution.TryGetValue(sample.Label, out int count);
        distribution[sample.Label] = count + 1;
    }

    return distribution;
}
119+
120+
public abstract Dictionary<string, double> Predict(string text);
47121
}
48122
}
+104-45
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,139 @@
11
using MyML.Abstracts;
2-
using MyML.Interfaces;
32

43
namespace MyML
54
{
65
public class MultinomialNaiveBayesClassifier : NaiveBayesClassifier
{
    /// <summary>
    /// Predicts class probabilities for a given text input using Multinomial Naive
    /// Bayes classification. Returns normalized probabilities (as percentages) for
    /// each trained class label.
    /// </summary>
    /// <remarks>
    /// The prediction process:
    ///
    /// 1. Compute posterior probabilities in log space:
    ///    P(class|text) ∝ log(P(class)) + Σ log(P(word|class)),
    ///    with Laplace smoothing applied to the word likelihoods.
    ///
    /// 2. Track the maximum log posterior for the log-sum-exp trick.
    ///
    /// 3. Convert back to normal space with log-sum-exp to avoid underflow:
    ///    exp(logP - maxLogP) / Σ exp(logP - maxLogP).
    ///
    /// 4. Scale the normalized probabilities to percentages (sum to 100).
    /// </remarks>
    /// <param name="text">Input text to classify.</param>
    /// <returns>
    /// Dictionary mapping class labels to predicted probabilities (percentages).
    /// Empty when the classifier has not been trained.
    /// </returns>
    public override Dictionary<string, double> Predict(string text)
    {
        // The vocabulary must be derived from the *trained* word counts. (An earlier
        // revision computed it once in the constructor, before any training had
        // happened, so the smoothing denominator was always missing the vocabulary.)
        int vocabularySize = CalculateVocabularySize();

        // Train increments totalWordsPerLabel once per training sample, so these are
        // document counts; their sum drives the class prior P(class).
        int totalDocumentCount = totalWordsPerLabel.Values.Sum();

        // Materialize once: Tokenize is virtual and the tokens are re-enumerated
        // for every label below.
        List<string> words = Tokenize(text).ToList();

        var logPosteriors = new Dictionary<string, double>();
        double maxLogPosterior = double.MinValue;

        foreach (var label in wordCountsPerLabel.Keys)
        {
            var labelWordCounts = wordCountsPerLabel[label];

            // The multinomial smoothing denominator needs the number of word
            // *occurrences* in the class, not the number of documents, so it is
            // computed from the per-word counts rather than totalWordsPerLabel.
            int wordOccurrencesInClass = labelWordCounts.Values.Sum();

            double logLikelihood = CalculateProbability(
                words,
                labelWordCounts,
                wordOccurrencesInClass,
                vocabularySize
            );

            double logPrior = Math.Log((double)totalWordsPerLabel[label] / totalDocumentCount);
            double logPosterior = logLikelihood + logPrior;

            logPosteriors[label] = logPosterior;
            if (logPosterior > maxLogPosterior)
            {
                maxLogPosterior = logPosterior;
            }
        }

        // Log-sum-exp: shift every log posterior by the maximum before
        // exponentiating so the largest term becomes exp(0) = 1.
        double sumExp = 0.0;
        foreach (double logPosterior in logPosteriors.Values)
        {
            sumExp += Math.Exp(logPosterior - maxLogPosterior);
        }

        var result = new Dictionary<string, double>();
        foreach (var kvp in logPosteriors)
        {
            // Normalize and scale to a percentage.
            result[kvp.Key] = Math.Exp(kvp.Value - maxLogPosterior) / sumExp * 100;
        }

        return result;
    }

    /// <summary>
    /// Calculates the log likelihood of a document belonging to a specific class
    /// using the Multinomial Naive Bayes model with Laplace (add-one) smoothing.
    /// </summary>
    /// <remarks>
    /// 1. Works in log space to prevent numerical underflow.
    /// 2. Applies Laplace smoothing so unseen words contribute a small non-zero
    ///    probability.
    /// 3. Assumes word independence (the naive assumption).
    ///
    /// Each word's smoothed likelihood is:
    /// P(word|class) = (count(word,class) + 1) / (totalWordsInClass + vocabularySize)
    /// </remarks>
    /// <param name="words">Words of the document to classify.</param>
    /// <param name="wordCountsForClass">Per-word occurrence counts for the class.</param>
    /// <param name="totalWordsInClass">Total word occurrences in the class's training data.</param>
    /// <param name="vocabularySize">Number of distinct words across all classes.</param>
    /// <returns>
    /// Log likelihood of the document under the class; higher values indicate a
    /// stronger association.
    /// </returns>
    private double CalculateProbability(
        IEnumerable<string> words,
        Dictionary<string, int> wordCountsForClass,
        int totalWordsInClass,
        int vocabularySize)
    {
        double logProbability = 0.0;

        foreach (var word in words)
        {
            int count = wordCountsForClass.TryGetValue(word, out int c) ? c : 0;

            // Laplace smoothing in log space.
            logProbability += Math.Log((count + 1.0) / (totalWordsInClass + vocabularySize));
        }

        return logProbability;
    }

    /// <summary>
    /// Counts the distinct words observed across all classes during training.
    /// Recomputed at prediction time so it always reflects the current model state.
    /// </summary>
    private int CalculateVocabularySize()
    {
        HashSet<string> uniqueWords = new();

        foreach (var labelCounts in wordCountsPerLabel.Values)
        {
            uniqueWords.UnionWith(labelCounts.Keys);
        }

        return uniqueWords.Count;
    }
}
80139
}

0 commit comments

Comments
 (0)