Skip to content

Commit 3347051

Browse files
authored
Merge pull request #46 from koheiw/dev-corssprod
Fix upgrade function
2 parents ed05bac + 09ede9f commit 3347051

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

R/utils.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ probability <- function(x, targets, layer = c("words", "documents"),
145145
targets <- targets[b]
146146

147147
values <- as.matrix(x, layer = layer, normalize = FALSE)
148-
e <- exp(values %*% t(x$weights[names(targets),, drop = FALSE]))
148+
e <- exp(tcrossprod(values, x$weights[names(targets),, drop = FALSE]))
149149
prob <- e / (e + 1) # sigmoid function
150150

151151
res <- prob %*% diag(targets)
@@ -184,7 +184,7 @@ perplexity <- function(x, targets, data) {
184184
data <- dfm(data, remove_padding = TRUE, tolower = x$tolower)
185185

186186
p <- probability(x, targets, mode = "numeric")
187-
pred <- dfm_match(dfm_weight(data, "prop"), rownames(p)) %*% p
187+
pred <- crossprod(t(dfm_match(dfm_weight(data, "prop"), rownames(p))), p)
188188
tri <- Matrix::mat2triplet(dfm_match(data, colnames(pred)))
189189
exp(-sum(tri$x * log(pred[cbind(tri$i, tri$j)])) / sum(tri$x))
190190
}
@@ -207,6 +207,12 @@ get_threads <- function() {
207207

208208
upgrade_pre06 <- function(x) {
209209

210+
if (is.null(x$tolower)) {
211+
x$tolower <- TRUE
212+
}
213+
if (is.numeric(x$type)) {
214+
x$type <- c("cbow", "sg")[x$type]
215+
}
210216
if (is.list(x$values))
211217
return(x)
212218
if (identical(class(x), "textmodel_wordvector")) {
@@ -216,12 +222,6 @@ upgrade_pre06 <- function(x) {
216222
x$values <- list(doc = x$values)
217223
class(x) <- c("textmodel_doc2vec", "textmodel_wordvector")
218224
}
219-
if (is.numeric(x$type)) {
220-
x$type <- c("cbow", "sg")[x$type]
221-
}
222-
if (is.null(x$tolower)) {
223-
x$tolower <- TRUE
224-
}
225225
return(x)
226226
}
227227

0 commit comments

Comments
 (0)