Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
default_stages: [commit, push]
default_stages: [pre-commit, pre-push]

ci:
autofix_commit_msg: |
Expand All @@ -14,30 +14,30 @@ ci:

repos:
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 7.0.0
hooks:
- id: isort
name: isort (python)
args: ["--profile", "black", "--filter-files", "--skip __init__.py"]

- repo: https://github.com/asottile/add-trailing-comma
rev: v2.2.3
rev: v4.0.0
hooks:
- id: add-trailing-comma

- repo: https://github.com/myint/docformatter
rev: v1.3.1
- repo: https://github.com/PyCQA/docformatter
rev: v1.7.7
hooks:
- id: docformatter
args: [--in-place]

- repo: https://github.com/psf/black
rev: 22.3.0
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.12.0
hooks:
- id: black

- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 7.3.0
hooks:
- id: flake8
args: [--config, .flake8]
27 changes: 17 additions & 10 deletions mess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
from tqdm import trange

from tweetopic._doc import init_doc_words
from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
predict_doc, sparse_multinomial_logpdf,
symmetric_dirichlet_logpdf,
symmetric_dirichlet_multinomial_logpdf)
from tweetopic.bayesian.dmm import (
BayesianDMM,
posterior_predictive,
predict_doc,
sparse_multinomial_logpdf,
symmetric_dirichlet_logpdf,
symmetric_dirichlet_multinomial_logpdf,
)
from tweetopic.bayesian.sampling import batch_data, sample_nuts
from tweetopic.func import spread

Expand Down Expand Up @@ -58,23 +62,26 @@ def logprior_fn(params):

def loglikelihood_fn(params, data):
doc_likelihood = jax.vmap(
partial(sparse_multinomial_logpdf, component=params["component"])
partial(sparse_multinomial_logpdf, component=params["component"]),
)
return jnp.sum(
doc_likelihood(
unique_words=data["doc_unique_words"],
unique_word_counts=data["doc_unique_word_counts"],
)
),
)


logdensity_fn(position)

logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
params, data
params,
data,
)
grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
logprior_fn, loglikelihood_fn, data_size=n_documents
logprior_fn,
loglikelihood_fn,
data_size=n_documents,
)
rng_key = jax.random.PRNGKey(0)
batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
Expand All @@ -88,8 +95,8 @@ def loglikelihood_fn(params, data):
)
position = dict(
component=jnp.array(
transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
)
transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha))),
),
)

samples, states = sample_nuts(position, logdensity_fn)
Expand Down
43 changes: 32 additions & 11 deletions tweetopic/_btm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Module for utility functions for fitting BTMs"""
"""Module for utility functions for fitting BTMs."""

import random
from typing import Dict, Tuple, TypeVar
Expand All @@ -12,7 +12,8 @@

@njit
def doc_unique_biterms(
doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
doc_unique_words: np.ndarray,
doc_unique_word_counts: np.ndarray,
) -> Dict[Tuple[int, int], int]:
(n_max_unique_words,) = doc_unique_words.shape
biterm_counts = dict()
Expand Down Expand Up @@ -43,7 +44,7 @@ def doc_unique_biterms(

@njit
def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
"""Adds one counter dict to another in place with Numba"""
"""Adds one counter dict to another in place with Numba."""
for key in source:
if key in dest:
dest[key] += source[key]
Expand All @@ -53,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):

@njit
def corpus_unique_biterms(
doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
doc_unique_words: np.ndarray,
doc_unique_word_counts: np.ndarray,
) -> Dict[Tuple[int, int], int]:
n_documents, _ = doc_unique_words.shape
biterm_counts = doc_unique_biterms(
doc_unique_words[0], doc_unique_word_counts[0]
doc_unique_words[0],
doc_unique_word_counts[0],
)
for i_doc in range(1, n_documents):
doc_unique_words_i = doc_unique_words[i_doc]
doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
doc_biterms = doc_unique_biterms(
doc_unique_words_i, doc_unique_word_counts_i
doc_unique_words_i,
doc_unique_word_counts_i,
)
nb_add_counter(biterm_counts, doc_biterms)
return biterm_counts


@njit
def compute_biterm_set(
biterm_counts: Dict[Tuple[int, int], int]
biterm_counts: Dict[Tuple[int, int], int],
) -> np.ndarray:
return np.array(list(biterm_counts.keys()))

Expand Down Expand Up @@ -116,7 +120,12 @@ def add_biterm(
topic_biterm_count: np.ndarray,
) -> None:
add_remove_biterm(
True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
True,
i_biterm,
i_topic,
biterms,
topic_word_count,
topic_biterm_count,
)


Expand All @@ -129,7 +138,12 @@ def remove_biterm(
topic_biterm_count: np.ndarray,
) -> None:
add_remove_biterm(
False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
False,
i_biterm,
i_topic,
biterms,
topic_word_count,
topic_biterm_count,
)


Expand All @@ -147,7 +161,11 @@ def init_components(
i_topic = random.randint(0, n_components - 1)
biterm_topic_assignments[i_biterm] = i_topic
add_biterm(
i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
i_biterm,
i_topic,
biterms,
topic_word_count,
topic_biterm_count,
)
return biterm_topic_assignments, topic_word_count, topic_biterm_count

Expand Down Expand Up @@ -448,7 +466,10 @@ def predict_docs(
)
biterms = doc_unique_biterms(words, word_counts)
prob_topic_given_document(
pred, biterms, topic_distribution, topic_word_distribution
pred,
biterms,
topic_distribution,
topic_word_distribution,
)
predictions[i_doc, :] = pred
return predictions
Expand Down
4 changes: 3 additions & 1 deletion tweetopic/_dmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
"""Module containing tools for fitting a Dirichlet Multinomial Mixture
Model."""

from __future__ import annotations

from math import exp, log
Expand Down
2 changes: 1 addition & 1 deletion tweetopic/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def init_doc_words(
n_docs, _ = doc_term_matrix.shape
doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
np.uint32
np.uint32,
)
for i_doc in range(n_docs):
unique_words = doc_term_matrix[i_doc].rows[0] # type: ignore
Expand Down
18 changes: 11 additions & 7 deletions tweetopic/btm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
import sklearn
from numpy.typing import ArrayLike

from tweetopic._btm import (compute_biterm_set, corpus_unique_biterms,
fit_model, predict_docs)
from tweetopic._btm import (
compute_biterm_set,
corpus_unique_biterms,
fit_model,
predict_docs,
)
from tweetopic._doc import init_doc_words
from tweetopic.exceptions import NotFittedException
from tweetopic.utils import set_numba_seed


class BTM(sklearn.base.TransformerMixin, sklearn.base.BaseEstimator):
"""Implementation of the Biterm Topic Model with Gibbs Sampling
solver.
"""Implementation of the Biterm Topic Model with Gibbs Sampling solver.

Parameters
----------
Expand Down Expand Up @@ -144,7 +147,9 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
X.tolil(),
max_unique_words=max_unique_words,
)
biterms = corpus_unique_biterms(doc_unique_words, doc_unique_word_counts)
biterms = corpus_unique_biterms(
doc_unique_words, doc_unique_word_counts
)
biterm_set = compute_biterm_set(biterms)
self.topic_distribution, self.components_ = fit_model(
n_iter=self.n_iterations,
Expand All @@ -159,8 +164,7 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
# TODO: Something goes terribly wrong here, fix this

def transform(self, X: Union[spr.spmatrix, ArrayLike]) -> np.ndarray:
"""Predicts probabilities for each document belonging to each
topic.
"""Predicts probabilities for each document belonging to each topic.

Parameters
----------
Expand Down
5 changes: 3 additions & 2 deletions tweetopic/func.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Utility functions for use in the library."""

from functools import wraps
from typing import Callable


def spread(fn: Callable):
"""Creates a new function from the given function so that it takes one
dict (PyTree) and spreads the arguments."""
"""Creates a new function from the given function so that it takes one dict
(PyTree) and spreads the arguments."""

@wraps(fn)
def inner(kwargs):
Expand Down
Loading