Skip to content

Commit 070fc81

Browse files
authored
Merge pull request #44 from koheiw/dev-perplexity
Dev perplexity
2 parents b297291 + 46c0d81 commit 070fc81

File tree

13 files changed

+137
-34
lines changed

13 files changed

+137
-34
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ S3method(textmodel_lsa,tokens)
1515
S3method(textmodel_word2vec,tokens)
1616
export(analogy)
1717
export(as.textmodel_doc2vec)
18+
export(perplexity)
1819
export(probability)
1920
export(similarity)
2021
export(textmodel_doc2vec)

R/as.doc2vec.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ as.textmodel_doc2vec.dfm <- function(x, model = NULL, normalize = FALSE,
6565
result <- list(
6666
"values" = list("word" = wov, "doc" = dov),
6767
"dim" = model$dim,
68+
"tolower" = model$tolower,
6869
"concatenator" = conc,
6970
"docvars" = x@docvars,
7071
"normalize" = normalize,

R/lsa.R

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,16 @@ textmodel_lsa.dfm <- function(x, dim = 50L, min_count = 5L,
9090
rownames(wov) <- featnames(x)
9191
}
9292
result <- list(
93-
values = list(word = wov),
94-
dim = dim,
95-
frequency = featfreq(x),
96-
engine = engine,
97-
weight = weight,
98-
min_count = min_count,
99-
concatenator = meta(x, field = "concatenator", type = "object"),
100-
call = try(match.call(sys.function(-1), call = sys.call(-1)), silent = TRUE),
101-
version = utils::packageVersion("wordvector")
93+
"values" = list(word = wov),
94+
"dim" = dim,
95+
"frequency" = featfreq(x),
96+
"engine" = engine,
97+
"weight" = weight,
98+
"min_count" = min_count,
99+
"tolower" = tolower,
100+
"concatenator" = meta(x, field = "concatenator", type = "object"),
101+
"call" = try(match.call(sys.function(-1), call = sys.call(-1)), silent = TRUE),
102+
"version" = utils::packageVersion("wordvector")
102103
)
103104
class(result) <- c("textmodel_lsa", "textmodel_wordvector")
104105
return(result)

R/utils.R

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ similarity <- function(x, targets, layer = c("words", "documents"),
9595
#' Compute probability of words
9696
#'
9797
#' Compute the probability of words given other words.
98-
#' @param x a `textmodel_wordvector` object fitted with `normalize = FALSE`.
99-
#' @param targets words for which probability is computed.
100-
#' @param layer the layer based on which probability is computed.
98+
#' @param x a trained `textmodel_wordvector` object.
99+
#' @param targets words for which probabilities are computed.
100+
#' @param layer the layer based on which probabilities are computed.
101101
#' @param mode specify the type of resulting object.
102102
#' @return a matrix of words or documents sorted in descending order by the probability
103103
#' scores when `mode = "character"`; a matrix of the probability scores when `mode = "numeric"`.
@@ -164,6 +164,31 @@ probability <- function(x, targets, layer = c("words", "documents"),
164164
return(res)
165165
}
166166

167+
#' Compute perplexity of a model
168+
#'
169+
#' Compute the perplexity of a trained word2vec model on the supplied data.
170+
#' @param x a trained `textmodel_wordvector` object.
171+
#' @param targets words for which probabilities are computed.
172+
#' @param data a [quanteda::tokens] or [quanteda::dfm]; the probabilities of words are
173+
#' tested against occurrences of words in it.
174+
#' @export
175+
#' @keywords internal
176+
perplexity <- function(x, targets, data) {
177+
x <- upgrade_pre06(x)
178+
179+
if (!is.character(targets))
180+
stop("targets must be a character vector")
181+
182+
if (!is.tokens(data) && !is.dfm(data))
183+
stop("data must be a tokens or dfm")
184+
data <- dfm(data, remove_padding = TRUE, tolower = x$tolower)
185+
186+
p <- probability(x, targets, mode = "numeric")
187+
pred <- dfm_match(dfm_weight(data, "prop"), rownames(p)) %*% p
188+
tri <- Matrix::mat2triplet(dfm_match(data, colnames(pred)))
189+
exp(-sum(tri$x * log(pred[cbind(tri$i, tri$j)])) / sum(tri$x))
190+
}
191+
167192
get_threads <- function() {
168193

169194
# respect other settings
@@ -194,6 +219,9 @@ upgrade_pre06 <- function(x) {
194219
if (is.numeric(x$type)) {
195220
x$type <- c("cbow", "sg")[x$type]
196221
}
222+
if (is.null(x$tolower)) {
223+
x$tolower <- TRUE
224+
}
197225
return(x)
198226
}
199227

R/word2vec.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ wordvector <- function(x, dim = 50, type = c("cbow", "sg", "dm", "dbow"),
160160

161161
result$type <- type
162162
result$min_count <- min_count
163+
result$tolower <- tolower
163164
result$concatenator <- meta(x, field = "concatenator", type = "object")
164165
if (include_data) # NOTE: consider removing
165166
result$data <- y

man/perplexity.Rd

Lines changed: 20 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/probability.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/misc/test_small.R

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,36 @@ wdv$weights["america",]
3030
wdv2$weights["america",]
3131
wdv3$weights["america",]
3232

33-
#dov <- textmodel_doc2vec(toks, dim = 50, type = "cbow", min_count = 5, verbose = TRUE, iter = 10)
34-
dov <- textmodel_doc2vec(toks, dim = 50, type = "skip-gram", min_count = 5, verbose = TRUE, iter = 10)
33+
dov <- textmodel_doc2vec(toks, dim = 100, type = "dm", min_count = 5, verbose = TRUE, iter = 10)
34+
dov2 <- textmodel_doc2vec(toks, dim = 100, type = "dm", min_count = 5, verbose = TRUE, iter = 20)
3535

36-
similarity(dov, analogy(~ washington - america + france)) %>%
37-
head()
36+
sim <- proxyC::simil(
37+
dov$values$doc,
38+
dov$values$doc["3555430",, drop = FALSE]
39+
)
40+
tail(sort(s <- rowSums(sim)))
3841

39-
probability(dov, c("good")) %>%
40-
head()
42+
sim2 <- proxyC::simil(
43+
dov2$values$doc,
44+
dov2$values$doc["3555430",, drop = FALSE]
45+
)
46+
tail(sort(s2 <- rowSums(sim2)))
47+
48+
identical(s, s2)
49+
cor(s, s2)
4150

4251
sim <- proxyC::simil(
43-
dov$values$doc,
44-
dov$values$doc["4263794",, drop = FALSE]
52+
dov2$values$doc,
53+
dov2$values$doc["4263794",, drop = FALSE]
4554
)
4655
sim <- proxyC::simil(
47-
dov$values$doc,
48-
dov$values$doc["3016236",, drop = FALSE]
56+
dov2$values$doc,
57+
dov2$values$doc["3016236",, drop = FALSE]
4958
)
5059
sim <- proxyC::simil(
51-
dov$values$doc,
52-
dov$values$doc["3555430",, drop = FALSE]
60+
dov2$values$doc,
61+
dov2$values$doc["3555430",, drop = FALSE]
5362
)
54-
5563
tail(sort(s <- rowSums(sim)))
5664
print(tail(toks[order(s)]), max_ntoken = -1)
5765

tests/testthat/test-as.doc2vec.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ test_that("textmodel_doc2vec works", {
2020
expect_false(dov1$normalize)
2121
expect_equal(
2222
names(dov1),
23-
c("values", "dim", "concatenator", "docvars", "normalize", "call", "version")
23+
c("values", "dim", "tolower", "concatenator", "docvars", "normalize", "call", "version")
2424
)
2525
expect_equal(
2626
dim(dov1$values$word), c(5363L, 50L)

tests/testthat/test-doc2vec.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ test_that("textmodel_doc2vec works", {
2121
expect_equal(
2222
names(dov1),
2323
c("values", "weights", "type", "dim", "frequency", "window", "iter", "alpha",
24-
"use_ns", "ns_size", "sample", "normalize", "min_count", "concatenator", "docvars", "call", "version")
24+
"use_ns", "ns_size", "sample", "normalize", "min_count", "tolower",
25+
"concatenator", "docvars", "call", "version")
2526
)
2627
expect_equal(
2728
dim(dov1$values$word), c(5363L, 50L)
@@ -64,7 +65,8 @@ test_that("textmodel_doc2vec works", {
6465
expect_equal(
6566
names(dov2),
6667
c("values", "weights", "type", "dim", "frequency", "window", "iter", "alpha",
67-
"use_ns", "ns_size", "sample", "normalize", "min_count", "concatenator", "docvars", "call", "version")
68+
"use_ns", "ns_size", "sample", "normalize", "min_count", "tolower",
69+
"concatenator", "docvars", "call", "version")
6870
)
6971
expect_null(
7072
dov2$values$word

0 commit comments

Comments
 (0)