@@ -324,14 +324,14 @@ namespace w2v {
324324 }
325325
326326 // compute gradient x alpha
327- auto error = (1 .0f - static_cast <float >(huffmanData->huffmanCode [i]) - f) * (*m_data.alpha );
327+ auto gxa = (1 .0f - static_cast <float >(huffmanData->huffmanCode [i]) - f) * (*m_data.alpha );
328328 // propagate errors output -> hidden
329329 for (std::size_t k = 0 ; k < K; ++k) {
330- _hiddenLayer[k] += error * (*m_data.bpWeights )[k + shift];
330+ _hiddenLayer[k] += gxa * (*m_data.bpWeights )[k + shift];
331331 }
332332 // learn weights hidden -> output
333333 for (std::size_t k = 0 ; k < K; ++k) {
334- (*m_data.bpWeights )[k + shift] += error * _trainLayer[k + _trainLayerShift];
334+ (*m_data.bpWeights )[k + shift] += gxa * _trainLayer[k + _trainLayerShift];
335335 }
336336 }
337337 }
@@ -346,9 +346,11 @@ namespace w2v {
346346 std::size_t target = 0 ;
347347 bool label = false ;
348348 if (i == 0 ) {
349+ // positive case
349350 target = _index;
350351 label = true ;
351352 } else {
353+ // negative case
352354 target = (*m_nsDistribution)(m_randomGenerator);
353355 if (target == _index) {
354356 continue ;
@@ -372,15 +374,15 @@ namespace w2v {
372374 }
373375
374376 // compute gradient x alpha
375- auto error = (static_cast <float >(label) - f) * (*m_data.alpha );
376- std::cout << _index << " , " << target << " , " << error << " \n " ;
377+ auto gxa = (static_cast <float >(label) - f) * (*m_data.alpha ); // gxa > 0 in the positive case
378+ // std::cout << i << ": " << _index << ", " << target << ", " << gxa << "\n";
377379 // propagate errors output -> hidden
378380 for (std::size_t k = 0 ; k < K; ++k) {
379- _hiddenLayer[k] += error * (*m_data.bpWeights )[k + shift];
381+ _hiddenLayer[k] += gxa * (*m_data.bpWeights )[k + shift];
380382 }
381383 // learn weights hidden -> output
382384 for (std::size_t k = 0 ; k < m_data.settings ->size ; ++k) {
383- (*m_data.bpWeights )[k + shift] += error * _trainLayer[k + _trainLayerShift];
385+ (*m_data.bpWeights )[k + shift] += gxa * _trainLayer[k + _trainLayerShift];
384386 }
385387 }
386388 }
0 commit comments