Skip to content

Commit 2125cd8

Browse files
committed
Rename variables
1 parent ac0f632 commit 2125cd8

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/word2vec/trainThread.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -324,14 +324,14 @@ namespace w2v {
324324
}
325325

326326
// compute gradient x alpha
327-
auto error = (1.0f - static_cast<float>(huffmanData->huffmanCode[i]) - f) * (*m_data.alpha);
327+
auto gxa = (1.0f - static_cast<float>(huffmanData->huffmanCode[i]) - f) * (*m_data.alpha);
328328
// propagate errors output -> hidden
329329
for (std::size_t k = 0; k < K; ++k) {
330-
_hiddenLayer[k] += error * (*m_data.bpWeights)[k + shift];
330+
_hiddenLayer[k] += gxa * (*m_data.bpWeights)[k + shift];
331331
}
332332
// learn weights hidden -> output
333333
for (std::size_t k = 0; k < K; ++k) {
334-
(*m_data.bpWeights)[k + shift] += error * _trainLayer[k + _trainLayerShift];
334+
(*m_data.bpWeights)[k + shift] += gxa * _trainLayer[k + _trainLayerShift];
335335
}
336336
}
337337
}
@@ -346,9 +346,11 @@ namespace w2v {
346346
std::size_t target = 0;
347347
bool label = false;
348348
if (i == 0) {
349+
// positive case
349350
target = _index;
350351
label = true;
351352
} else {
353+
// negative case
352354
target = (*m_nsDistribution)(m_randomGenerator);
353355
if (target == _index) {
354356
continue;
@@ -372,15 +374,15 @@ namespace w2v {
372374
}
373375

374376
// compute gradient x alpha
375-
auto error = (static_cast<float>(label) - f) * (*m_data.alpha);
376-
std::cout << _index << ", " << target << ", " << error << "\n";
377+
auto gxa = (static_cast<float>(label) - f) * (*m_data.alpha); // gxa > 0 in the positive case
378+
//std::cout << i << ": " << _index << ", " << target << ", " << gxa << "\n";
377379
// propagate errors output -> hidden
378380
for (std::size_t k = 0; k < K; ++k) {
379-
_hiddenLayer[k] += error * (*m_data.bpWeights)[k + shift];
381+
_hiddenLayer[k] += gxa * (*m_data.bpWeights)[k + shift];
380382
}
381383
// learn weights hidden -> output
382384
for (std::size_t k = 0; k < m_data.settings->size; ++k) {
383-
(*m_data.bpWeights)[k + shift] += error * _trainLayer[k + _trainLayerShift];
385+
(*m_data.bpWeights)[k + shift] += gxa * _trainLayer[k + _trainLayerShift];
384386
}
385387
}
386388
}

0 commit comments

Comments
 (0)