1 change: 1 addition & 0 deletions NAMESPACE
@@ -15,6 +15,7 @@ S3method(textmodel_lsa,tokens)
S3method(textmodel_word2vec,tokens)
export(analogy)
export(as.textmodel_doc2vec)
export(perplexity)
export(probability)
export(similarity)
export(textmodel_doc2vec)
1 change: 1 addition & 0 deletions R/as.doc2vec.R
@@ -65,6 +65,7 @@ as.textmodel_doc2vec.dfm <- function(x, model = NULL, normalize = FALSE,
result <- list(
"values" = list("word" = wov, "doc" = dov),
"dim" = model$dim,
"tolower" = model$tolower,
"concatenator" = conc,
"docvars" = x@docvars,
"normalize" = normalize,
19 changes: 10 additions & 9 deletions R/lsa.R
@@ -90,15 +90,16 @@ textmodel_lsa.dfm <- function(x, dim = 50L, min_count = 5L,
rownames(wov) <- featnames(x)
}
result <- list(
values = list(word = wov),
dim = dim,
frequency = featfreq(x),
engine = engine,
weight = weight,
min_count = min_count,
concatenator = meta(x, field = "concatenator", type = "object"),
call = try(match.call(sys.function(-1), call = sys.call(-1)), silent = TRUE),
version = utils::packageVersion("wordvector")
"values" = list(word = wov),
"dim" = dim,
"frequency" = featfreq(x),
"engine" = engine,
"weight" = weight,
"min_count" = min_count,
"tolower" = tolower,
"concatenator" = meta(x, field = "concatenator", type = "object"),
"call" = try(match.call(sys.function(-1), call = sys.call(-1)), silent = TRUE),
"version" = utils::packageVersion("wordvector")
)
class(result) <- c("textmodel_lsa", "textmodel_wordvector")
return(result)
34 changes: 31 additions & 3 deletions R/utils.R
@@ -95,9 +95,9 @@ similarity <- function(x, targets, layer = c("words", "documents"),
#' Compute probability of words
#'
#' Compute the probability of words given other words.
#' @param x a `textmodel_wordvector` object fitted with `normalize = FALSE`.
#' @param targets words for which probability is computed.
#' @param layer the layer based on which probability is computed.
#' @param x a trained `textmodel_wordvector` object.
#' @param targets words for which probabilities are computed.
#' @param layer the layer based on which probabilities are computed.
#' @param mode specify the type of resulting object.
#' @return a matrix of words or documents sorted in descending order by the probability
#' scores when `mode = "character"`; a matrix of the probability scores when `mode = "numeric"`.
@@ -164,6 +164,31 @@ probability <- function(x, targets, layer = c("words", "documents"),
return(res)
}
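
As a quick illustration of the two return modes documented above, a minimal sketch follows; `wov` is assumed to be a fitted word2vec model (as in the package tests), and the orientation of the numeric matrix is inferred from how perplexity() consumes it below.

# assumed: `wov` is a fitted textmodel_word2vec object
library(wordvector)

# words most strongly predicted for the target "good", sorted by descending probability
head(probability(wov, "good", mode = "character"))

# raw probability scores: rows are vocabulary words, one column per target
p <- probability(wov, c("good", "bad"), mode = "numeric")
head(p)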

#' Compute perplexity of a model
#'
#' Compute the perplexity of a trained word2vec model against observed data.
#' @param x a trained `textmodel_wordvector` object.
#' @param targets words for which probabilities are computed.
#' @param data a [quanteda::tokens] or [quanteda::dfm] object; the predicted probabilities
#' of the target words are evaluated against their observed occurrences in it.
#' @return a single numeric value; lower values indicate that the model fits the data better.
#' @export
#' @keywords internal
perplexity <- function(x, targets, data) {
    x <- upgrade_pre06(x)

    if (!is.character(targets))
        stop("targets must be a character vector")

    if (!is.tokens(data) && !is.dfm(data))
        stop("data must be a tokens or dfm")
    # convert to a dfm, applying the same case folding as the fitted model
    data <- dfm(data, remove_padding = TRUE, tolower = x$tolower)

    # predicted probability of each target word in each document:
    # relative word frequencies weighted by the model's word probabilities
    p <- probability(x, targets, mode = "numeric")
    pred <- dfm_match(dfm_weight(data, "prop"), rownames(p)) %*% p
    # observed counts of the target words in sparse triplet form
    tri <- Matrix::mat2triplet(dfm_match(data, colnames(pred)))
    # perplexity: exponential of the negative count-weighted mean log-probability
    exp(-sum(tri$x * log(pred[cbind(tri$i, tri$j)])) / sum(tri$x))
}

get_threads <- function() {

# respect other settings
@@ -194,6 +219,9 @@ upgrade_pre06 <- function(x) {
if (is.numeric(x$type)) {
x$type <- c("cbow", "sg")[x$type]
}
if (is.null(x$tolower)) {
x$tolower <- TRUE
}
return(x)
}
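
For orientation, a rough usage sketch of the new perplexity() helper follows; the objects `wov`, `toks`, and `dfmt` are assumed to be a fitted word2vec model, the tokens it was trained on, and dfm(toks), mirroring the testthat setup. Lower values mean the model assigns higher probability to the words actually observed.

# assumed, mirroring the test setup:
#   wov  <- textmodel_word2vec(toks, dim = 50, min_count = 5)
#   dfmt <- dfm(toks)
library(quanteda)
library(wordvector)

targets <- c("good", "bad", "people", "government")

# evaluate the model's predicted word probabilities against observed counts
ppl_dfm  <- perplexity(wov, targets, dfmt)
ppl_toks <- perplexity(wov, targets, toks)  # a tokens input is converted to a dfm internally
c(dfm = ppl_dfm, tokens = ppl_toks)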

1 change: 1 addition & 0 deletions R/word2vec.R
@@ -160,6 +160,7 @@ wordvector <- function(x, dim = 50, type = c("cbow", "sg", "dm", "dbow"),

result$type <- type
result$min_count <- min_count
result$tolower <- tolower
result$concatenator <- meta(x, field = "concatenator", type = "object")
if (include_data) # NOTE: consider removing
result$data <- y
20 changes: 20 additions & 0 deletions man/perplexity.Rd


6 changes: 3 additions & 3 deletions man/probability.Rd


34 changes: 21 additions & 13 deletions tests/misc/test_small.R
@@ -30,28 +30,36 @@ wdv$weights["america",]
wdv2$weights["america",]
wdv3$weights["america",]

#dov <- textmodel_doc2vec(toks, dim = 50, type = "cbow", min_count = 5, verbose = TRUE, iter = 10)
dov <- textmodel_doc2vec(toks, dim = 50, type = "skip-gram", min_count = 5, verbose = TRUE, iter = 10)
dov <- textmodel_doc2vec(toks, dim = 100, type = "dm", min_count = 5, verbose = TRUE, iter = 10)
dov2 <- textmodel_doc2vec(toks, dim = 100, type = "dm", min_count = 5, verbose = TRUE, iter = 20)

similarity(dov, analogy(~ washington - america + france)) %>%
head()
sim <- proxyC::simil(
dov$values$doc,
dov$values$doc["3555430",, drop = FALSE]
)
tail(sort(s <- rowSums(sim)))

probability(dov, c("good")) %>%
head()
sim2 <- proxyC::simil(
dov2$values$doc,
dov2$values$doc["3555430",, drop = FALSE]
)
tail(sort(s2 <- rowSums(sim2)))

identical(s, s2)
cor(s, s2)

sim <- proxyC::simil(
dov$values$doc,
dov$values$doc["4263794",, drop = FALSE]
dov2$values$doc,
dov2$values$doc["4263794",, drop = FALSE]
)
sim <- proxyC::simil(
dov$values$doc,
dov$values$doc["3016236",, drop = FALSE]
dov2$values$doc,
dov2$values$doc["3016236",, drop = FALSE]
)
sim <- proxyC::simil(
dov$values$doc,
dov$values$doc["3555430",, drop = FALSE]
dov2$values$doc,
dov2$values$doc["3555430",, drop = FALSE]
)

tail(sort(s <- rowSums(sim)))
print(tail(toks[order(s)]), max_ntoken = -1)

2 changes: 1 addition & 1 deletion tests/testthat/test-as.doc2vec.R
@@ -20,7 +20,7 @@ test_that("textmodel_doc2vec works", {
expect_false(dov1$normalize)
expect_equal(
names(dov1),
c("values", "dim", "concatenator", "docvars", "normalize", "call", "version")
c("values", "dim", "tolower", "concatenator", "docvars", "normalize", "call", "version")
)
expect_equal(
dim(dov1$values$word), c(5363L, 50L)
6 changes: 4 additions & 2 deletions tests/testthat/test-doc2vec.R
@@ -21,7 +21,8 @@ test_that("textmodel_doc2vec works", {
expect_equal(
names(dov1),
c("values", "weights", "type", "dim", "frequency", "window", "iter", "alpha",
"use_ns", "ns_size", "sample", "normalize", "min_count", "concatenator", "docvars", "call", "version")
"use_ns", "ns_size", "sample", "normalize", "min_count", "tolower",
"concatenator", "docvars", "call", "version")
)
expect_equal(
dim(dov1$values$word), c(5363L, 50L)
@@ -64,7 +65,8 @@ test_that("textmodel_doc2vec works", {
expect_equal(
names(dov2),
c("values", "weights", "type", "dim", "frequency", "window", "iter", "alpha",
"use_ns", "ns_size", "sample", "normalize", "min_count", "concatenator", "docvars", "call", "version")
"use_ns", "ns_size", "sample", "normalize", "min_count", "tolower",
"concatenator", "docvars", "call", "version")
)
expect_null(
dov2$values$word
7 changes: 4 additions & 3 deletions tests/testthat/test-lsa.R
@@ -49,7 +49,8 @@ test_that("word2vec words", {
)
expect_equal(
names(dov),
c("values", "dim", "concatenator", "docvars", "normalize", "call", "version")
c("values", "dim", "tolower", "concatenator", "docvars", "normalize",
"call", "version")
)

# docvector with model
@@ -78,7 +79,7 @@ test_that("word2vec words", {
)
expect_equal(
names(dov),
c("values", "dim", "concatenator", "docvars", "normalize", "call", "version")
c("values", "dim", "tolower", "concatenator", "docvars", "normalize", "call", "version")
)

# docvector with grouped data
@@ -94,7 +95,7 @@
)
expect_equal(
names(dov_gp),
c("values", "dim", "concatenator", "docvars", "normalize", "call", "version")
c("values", "dim", "tolower", "concatenator", "docvars", "normalize", "call", "version")
)
})

34 changes: 34 additions & 0 deletions tests/testthat/test-utils.R
@@ -430,3 +430,37 @@ test_that("old arguments still works", {
)

})

test_that("perplexity works", {

# infrequent words
word1 <- c("good", "nice", "excellent", "positive", "fortunate", "correct", "superior",
"bad", "nasty", "poor", "negative", "unfortunate", "wrong", "inferior")
suppressWarnings(
ppl1 <- perplexity(wov, word1, dfmt)
)
expect_gt(ppl1, 3.0)

# frequent words
word2 <- c("america", "us", "people", "government", "state", "nation", "world", "peace", "public")
suppressWarnings(
ppl2 <- perplexity(wov, word2, dfmt)
)
expect_lt(ppl2, ppl1)

# tokens_object
suppressWarnings(
ppl3 <- perplexity(wov, word2, toks)
)
expect_equal(ppl3, ppl2)

expect_error(
perplexity(wov, c("good" = 1, "bad" = -1), dfmt),
"targets must be a character vector"
)

expect_error(
perplexity(wov, word2, list),
"data must be a tokens or dfm"
)
})
6 changes: 6 additions & 0 deletions tests/testthat/test-word2vec.R
@@ -49,6 +49,9 @@ test_that("textmodel_word2vec works", {
featfreq(dfm_trim(dfm(toks), 2)),
wov1$frequency
)
expect_true(
wov1$tolower
)

expect_output(
print(wov1),
@@ -113,6 +116,9 @@ test_that("textmodel_word2vec works", {
featfreq(dfm_trim(dfm(toks), 2)),
wov2$frequency
)
expect_true(
wov2$tolower
)

expect_output(
print(wov2),