Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/dev.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include <chrono>
#include <string>
#include <map>

#ifndef __DEV__
#define __DEV__

using namespace Rcpp;

namespace dev{


/* ---- Profiling tools ------------------------- */

typedef std::map<std::string, std::chrono::time_point<std::chrono::high_resolution_clock> > Timer;

inline void start_timer(std::string label, Timer &timer){
auto now = std::chrono::high_resolution_clock::now();
timer[label] = now;
}

inline void stop_timer(std::string label, Timer &timer){
if (timer.find(label) == timer.end()){
Rcout << std::left << std::setw(20) << "'" + label + "'";
Rcout << " is not timed\n";
}else{
Rcout << std::left << std::setw(20) << "'" + label + "'";
Rcout << " ";
auto now = std::chrono::high_resolution_clock::now();
Rcout << std::chrono::duration<double, std::milli>(now - timer[label]).count();
Rcout << " millsec\n";
}
}
}

#endif
49 changes: 33 additions & 16 deletions src/wordvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <mutex>
#include "word2vec/word2vec.hpp"
#include "tokens.h"
#include "dev.h"

typedef XPtr<TokensObj> TokensPtr;
typedef std::vector<std::string> vocabulary_t;
Expand All @@ -20,32 +21,48 @@ Rcpp::CharacterVector encode(std::vector<std::string> types){
return types_;
}

Rcpp::NumericMatrix as_matrix(std::vector<float> mat,
std::size_t nrow, std::size_t ncol) {

if (mat.size() == 0)
return Rcpp::NumericMatrix();
if (nrow * ncol != mat.size())
throw std::runtime_error("Invalid matrix size");
Rcpp::NumericMatrix mat_(nrow, ncol);
for (std::size_t i = 0; i < nrow; ++i) {
for (std::size_t j = 0; j < ncol; ++j) {
mat_(i, j) = mat[i * ncol + j];
}
}
return mat_;
}


Rcpp::NumericMatrix get_weights(w2v::word2vec_t model) {
std::vector<float> mat = model.weights();
if (model.vectorSize() * model.vocabularySize() != mat.size())
throw std::runtime_error("Invalid weight matrix");
Rcpp::NumericMatrix mat_ = as_matrix(mat, model.vocabularySize(), model.vectorSize());
rownames(mat_) = encode(model.vocabulary());
return mat_;
}

Rcpp::NumericMatrix get_words(w2v::word2vec_t model) {
std::vector<float> mat = model.values();
if (model.vectorSize() * model.vocabularySize() != mat.size())
throw std::runtime_error("Invalid word matrix");
Rcpp::NumericMatrix mat_(model.vectorSize(), model.vocabularySize(), mat.begin());
colnames(mat_) = encode(model.vocabulary());
return Rcpp::transpose(mat_);
Rcpp::NumericMatrix mat_ = as_matrix(mat, model.vocabularySize(), model.vectorSize());
rownames(mat_) = encode(model.vocabulary());
return mat_;
}

Rcpp::NumericMatrix get_documents(w2v::word2vec_t model) {
std::vector<float> mat = model.docValues();
if (mat.size() == 0)
return Rcpp::NumericMatrix();
if (model.vectorSize() * model.corpusSize() != mat.size())
throw std::runtime_error("Invalid document matrix");
Rcpp::NumericMatrix mat_(model.vectorSize(), model.corpusSize(), mat.begin());
return Rcpp::transpose(mat_);
}

Rcpp::NumericMatrix get_weights(w2v::word2vec_t model) {
std::vector<float> mat = model.weights();
if (model.vectorSize() * model.vocabularySize() != mat.size())
throw std::runtime_error("Invalid weight matrix");
Rcpp::NumericMatrix mat_(model.vectorSize(), model.vocabularySize(), mat.begin());
colnames(mat_) = encode(model.vocabulary());
return Rcpp::transpose(mat_);
Rcpp::NumericMatrix mat_ = as_matrix(mat, model.corpusSize(), model.vectorSize());
// TODO: add document names here
return mat_;
}

Rcpp::NumericVector get_frequency(w2v::corpus_t corpus) {
Expand Down
Loading