@@ -95,9 +95,9 @@ similarity <- function(x, targets, layer = c("words", "documents"),
9595# ' Compute probability of words
9696# '
9797# ' Compute the probability of words given other words.
98- # ' @param x a `textmodel_wordvector` object fitted with `normalize = FALSE` .
99- # ' @param targets words for which probability is computed.
100- # ' @param layer the layer based on which probability is computed.
98+ # ' @param x a trained `textmodel_wordvector` object.
99+ # ' @param targets words for which probabilities are computed.
100+ # ' @param layer the layer based on which probabilities are computed.
101101# ' @param mode specify the type of resulting object.
102102# ' @return a matrix of words or documents sorted in descending order by the probability
103103# ' scores when `mode = "character"`; a matrix of the probability scores when `mode = "numeric"`.
@@ -164,6 +164,31 @@ probability <- function(x, targets, layer = c("words", "documents"),
164164 return (res )
165165}
166166
#' Compute perplexity of a model
#'
#' Compute the perplexity of a trained word2vec model against the word
#' occurrences in `data`.
#' @param x a trained `textmodel_wordvector` object.
#' @param targets words for which probabilities are computed.
#' @param data a [quanteda::tokens] or [quanteda::dfm]; the probabilities of
#'   words are tested against occurrences of words in it.
#' @return a single numeric value: the exponential of the negative average
#'   log-probability of the occurrences of `targets` observed in `data`,
#'   weighted by their counts. The denominator is the total count of those
#'   occurrences, so the result is expected to be `NaN` when none of
#'   `targets` occurs in `data`.
#' @export
#' @keywords internal
perplexity <- function(x, targets, data) {
  x <- upgrade_pre06(x)

  if (!is.character(targets))
    stop("targets must be a character vector")

  if (!is.tokens(data) && !is.dfm(data))
    stop("data must be a tokens or dfm")
  # build the dfm with the casing setting stored on the model
  data <- dfm(data, remove_padding = TRUE, tolower = x$tolower)

  # probability scores from the model; its row names are aligned with the
  # features of `data` and its column names with the observed counts below
  p <- probability(x, targets, mode = "numeric")

  # mix the rows of `p` by each document's feature proportions
  pred <- dfm_match(dfm_weight(data, "prop"), rownames(p)) %*% p

  # observed counts as (i, j, x) triplets, restricted to the columns of `pred`
  tri <- Matrix::mat2triplet(dfm_match(data, colnames(pred)))

  # count-weighted mean negative log-probability, exponentiated
  exp(-sum(tri$x * log(pred[cbind(tri$i, tri$j)])) / sum(tri$x))
}
191+
167192get_threads <- function () {
168193
169194 # respect other settings
@@ -194,6 +219,9 @@ upgrade_pre06 <- function(x) {
194219 if (is.numeric(x $ type )) {
195220 x $ type <- c(" cbow" , " sg" )[x $ type ]
196221 }
222+ if (is.null(x $ tolower )) {
223+ x $ tolower <- TRUE
224+ }
197225 return (x )
198226}
199227
0 commit comments