
Commit

minor code improvements
maximtrp committed Sep 28, 2024
1 parent 17acb8b commit a2c9509
Showing 3 changed files with 187 additions and 172 deletions.
99 changes: 50 additions & 49 deletions src/tmplot/_distance.py
@@ -1,13 +1,17 @@
-__all__ = [
-    'get_topics_dist', 'get_topics_scatter', 'get_top_topic_words']
-from typing import Union, List
+__all__ = ["get_topics_dist", "get_topics_scatter", "get_top_topic_words"]
+from typing import Optional, Union, List
 from itertools import combinations
-from pandas import DataFrame
+from pandas import DataFrame, Index
 import numpy as np
 from scipy.special import kl_div
 from scipy.spatial import distance
 from sklearn.manifold import (
-    TSNE, Isomap, LocallyLinearEmbedding, MDS, SpectralEmbedding)
+    TSNE,
+    Isomap,
+    LocallyLinearEmbedding,
+    MDS,
+    SpectralEmbedding,
+)
 from ._helpers import calc_topics_marg_probs


@@ -28,15 +32,14 @@ def _dist_jsd(a1: np.ndarray, a2: np.ndarray):
 
 def _dist_jef(a1: np.ndarray, a2: np.ndarray):
     vals = (a1 - a2) * (np.log(a1) - np.log(a2))
-    vals[(vals <= 0) | ~np.isfinite(vals)] = 0.
+    vals[(vals <= 0) | ~np.isfinite(vals)] = 0.0
     return vals.sum()
 
 
 def _dist_hel(a1: np.ndarray, a2: np.ndarray):
     a1[(a1 <= 0) | ~np.isfinite(a1)] = 1e-64
     a2[(a2 <= 0) | ~np.isfinite(a2)] = 1e-64
-    hel_val = distance.euclidean(
-        np.sqrt(a1), np.sqrt(a2)) / np.sqrt(2)
+    hel_val = distance.euclidean(np.sqrt(a1), np.sqrt(a2)) / np.sqrt(2)
     return hel_val
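For orientation, a minimal sketch (not part of this commit) of the quantities the two helpers above compute, on a pair of made-up toy distributions:

# Toy example only; the distribution values are invented for illustration.
import numpy as np
from scipy.spatial import distance

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.4, 0.4, 0.2])

# Jeffreys divergence: symmetrised KL, i.e. the sum of (p - q) * (log p - log q)
jef = ((p - q) * (np.log(p) - np.log(q))).sum()

# Hellinger distance: Euclidean distance between sqrt-probabilities, scaled by 1/sqrt(2)
hel = distance.euclidean(np.sqrt(p), np.sqrt(q)) / np.sqrt(2)

print(round(jef, 3), round(hel, 3))  # approximately 0.051 and 0.080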


@@ -52,19 +55,18 @@ def _dist_tv(a1: np.ndarray, a2: np.ndarray):
     return dist
 
 
-def _dist_jac(a1: np.ndarray, a2: np.ndarray, top_words=100):
-    a = np.argsort(a1)[:-top_words-1:-1]
-    b = np.argsort(a2)[:-top_words-1:-1]
+def _dist_jac(a1: np.ndarray, a2: np.ndarray, top_words=100):
+    a = np.argsort(a1)[: -top_words - 1 : -1]
+    b = np.argsort(a2)[: -top_words - 1 : -1]
     j_num = np.intersect1d(a, b, assume_unique=False).size
     j_den = np.union1d(a, b).size
     jac_val = 1 - j_num / j_den
     return jac_val
 
 
 def get_topics_dist(
-        phi: Union[np.ndarray, DataFrame],
-        method: str = "sklb",
-        **kwargs) -> np.ndarray:
+    phi: Union[np.ndarray, DataFrame], method: str = "sklb", **kwargs
+) -> np.ndarray:
     """Finding closest topics in models.
 
     Parameters
@@ -110,16 +112,18 @@ def get_topics_dist(
     for i, j in topics_pairs:
         _dist_func = dist_funcs.get(method, "sklb")
         topics_dists[((i, j), (j, i))] = _dist_func(
-            phi_copy[:, i], phi_copy[:, j], **kwargs)
+            phi_copy[:, i], phi_copy[:, j], **kwargs
+        )
 
     return topics_dists
 
 
 def get_topics_scatter(
-        topic_dists: np.ndarray,
-        theta: np.ndarray,
-        method: str = 'tsne',
-        method_kws: dict = None) -> DataFrame:
+    topic_dists: np.ndarray,
+    theta: np.ndarray,
+    method: str = "tsne",
+    method_kws: Optional[dict] = None,
+) -> DataFrame:
     """Calculate topics coordinates for a scatter plot.
 
     Parameters
@@ -146,52 +150,52 @@ def get_topics_scatter(
     Topics scatter coordinates.
     """
     if not method_kws:
-        method_kws = {'n_components': 2}
+        method_kws = {"n_components": 2}
 
-    if method == 'tsne':
-        method_kws.setdefault('init', 'pca')
-        method_kws.setdefault('learning_rate', 'auto')
-        method_kws.setdefault(
-            'perplexity', min(50, max(topic_dists.shape[0] // 2, 1)))
+    if method == "tsne":
+        method_kws.setdefault("init", "pca")
+        method_kws.setdefault("learning_rate", "auto")
+        method_kws.setdefault("perplexity", min(50, max(topic_dists.shape[0] // 2, 1)))
         transformer = TSNE(**method_kws)
 
-    elif method == 'sem':
-        method_kws.setdefault('affinity', 'precomputed')
+    elif method == "sem":
+        method_kws.setdefault("affinity", "precomputed")
         transformer = SpectralEmbedding(**method_kws)
 
-    elif method == 'mds':
-        method_kws.setdefault('dissimilarity', 'precomputed')
-        method_kws.setdefault('normalized_stress', 'auto')
+    elif method == "mds":
+        method_kws.setdefault("dissimilarity", "precomputed")
+        method_kws.setdefault("normalized_stress", "auto")
        transformer = MDS(**method_kws)
 
-    elif method == 'lle':
-        method_kws['method'] = 'standard'
+    elif method == "lle":
+        method_kws["method"] = "standard"
         transformer = LocallyLinearEmbedding(**method_kws)
 
-    elif method == 'ltsa':
-        method_kws['method'] = 'ltsa'
+    elif method == "ltsa":
+        method_kws["method"] = "ltsa"
         transformer = LocallyLinearEmbedding(**method_kws)
 
-    elif method == 'isomap':
+    elif method == "isomap":
         transformer = Isomap(**method_kws)
 
     coords = transformer.fit_transform(topic_dists)
 
-    topics_xy = DataFrame(coords, columns=['x', 'y'])
-    topics_xy['topic'] = topics_xy.index.astype(int)
-    topics_xy['size'] = calc_topics_marg_probs(theta)
-    size_sum = topics_xy['size'].sum()
+    topics_xy = DataFrame(coords, columns=Index(["x", "y"]))
+    topics_xy["topic"] = topics_xy.index.astype(int)
+    topics_xy["size"] = calc_topics_marg_probs(theta)
+    size_sum = topics_xy["size"].sum()
     if size_sum > 0:
-        topics_xy['size'] *= (100 / topics_xy['size'].sum())
+        topics_xy["size"] *= 100 / topics_xy["size"].sum()
     else:
-        topics_xy['size'] = np.nan
+        topics_xy["size"] = np.nan
     return topics_xy
 
 
 def get_top_topic_words(
-        phi: DataFrame,
-        words_num: int = 20,
-        topics_idx: Union[List[int], np.ndarray] = None) -> DataFrame:
+    phi: DataFrame,
+    words_num: int = 20,
+    topics_idx: Optional[Union[List[int], np.ndarray]] = None,
+) -> DataFrame:
     """Select top topic words from a fitted model.
 
     Parameters
@@ -209,9 +213,6 @@ def get_top_topic_words(
     DataFrame
         Words with highest probabilities in all (or selected) topics.
     """
-    return phi.loc[:, topics_idx or phi.columns]\
-        .apply(
-            lambda x: x
-            .sort_values(ascending=False)
-            .head(words_num).index, axis=0
+    return phi.loc[:, topics_idx or phi.columns].apply(
+        lambda x: x.sort_values(ascending=False).head(words_num).index, axis=0
     )
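For context, a hedged usage sketch of the three public functions touched in this file. The matrix shapes, the random toy data, and the direct import from the private tmplot._distance module are assumptions for illustration, not taken from the repository docs:

# A hedged sketch, not part of this commit. phi is assumed to be a
# words-by-topics matrix and theta a topics-by-documents matrix.
import numpy as np
from pandas import DataFrame
from tmplot._distance import get_topics_dist, get_topics_scatter, get_top_topic_words

rng = np.random.default_rng(0)
phi = DataFrame(rng.dirichlet(np.ones(500), size=8).T)      # 500 terms x 8 topics
theta = rng.dirichlet(np.ones(8), size=100).T               # 8 topics x 100 documents

dists = get_topics_dist(phi, method="sklb")                 # pairwise inter-topic distances
scatter = get_topics_scatter(dists, theta, method="tsne")   # columns: x, y, topic, size
top_words = get_top_topic_words(phi, words_num=10)          # top-10 terms per topic

As the hunk above shows, get_topics_scatter returns a DataFrame with x, y, topic and size columns, with size rescaled to sum to 100 when the marginal topic probabilities are positive.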
50 changes: 29 additions & 21 deletions src/tmplot/_stability.py
@@ -1,9 +1,17 @@
-__all__ = ['get_closest_topics', 'get_stable_topics']
+__all__ = ["get_closest_topics", "get_stable_topics"]
 from typing import List, Tuple, Any
 import numpy as np
 import tqdm
-from ._distance import _dist_klb, _dist_sklb, _dist_jsd, _dist_jef, _dist_hel, \
-    _dist_bhat, _dist_jac, _dist_tv
+from ._distance import (
+    _dist_klb,
+    _dist_sklb,
+    _dist_jsd,
+    _dist_jef,
+    _dist_hel,
+    _dist_bhat,
+    _dist_jac,
+    _dist_tv,
+)
 from ._helpers import get_phi
 
 dist_funcs = {
@@ -19,11 +27,12 @@
 
 
 def get_closest_topics(
-        models: List[Any],
-        ref: int = 0,
-        method: str = "sklb",
-        top_words: int = 100,
-        verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+    models: List[Any],
+    ref: int = 0,
+    method: str = "sklb",
+    top_words: int = 100,
+    verbose: bool = True,
+) -> Tuple[np.ndarray, np.ndarray]:
     """Finding closest topics in models.
 
     Parameters
@@ -93,7 +102,6 @@ def enum_func(x):
 
     # Iterating over all models
     for mid, model in enum_func(models):
-
         # Current model is equal to reference model, skipping
         if mid == ref:
             continue
@@ -105,7 +113,8 @@ def enum_func(x):
         for t_ref in range(topics_num):
             for t in range(topics_num):
                 all_vs_all_dists[t_ref, t] = dist_func(
-                    model_ref_phi.iloc[:, t_ref], get_phi(model).iloc[:, t])
+                    model_ref_phi.iloc[:, t_ref], get_phi(model).iloc[:, t]
+                )
 
         # Creating two arrays for the closest topics ids and distance values
         if method == "jac":
@@ -119,14 +128,15 @@
 
 
 def get_stable_topics(
-        closest_topics: np.ndarray,
-        dist: np.ndarray,
-        norm: bool = True,
-        inverse: bool = True,
-        inverse_factor: float = 1.0,
-        ref: int = 0,
-        thres: float = 0.9,
-        thres_models: int = 2) -> Tuple[np.ndarray, np.ndarray]:
+    closest_topics: np.ndarray,
+    dist: np.ndarray,
+    norm: bool = True,
+    inverse: bool = True,
+    inverse_factor: float = 1.0,
+    ref: int = 0,
+    thres: float = 0.9,
+    thres_models: int = 2,
+) -> Tuple[np.ndarray, np.ndarray]:
     """Finding stable topics in models.
 
     Parameters
@@ -179,7 +189,5 @@ def get_stable_topics(
     dist_arr = np.asarray(dist)
     dist_ready = dist_arr / dist_arr.max() if norm else dist_arr.copy()
     dist_ready = inverse_factor - dist_ready if inverse else dist_ready
-    mask = (
-        np.sum(np.delete(dist_ready, ref, axis=1) >= thres, axis=1)
-        >= thres_models)
+    mask = np.sum(np.delete(dist_ready, ref, axis=1) >= thres, axis=1) >= thres_models
     return closest_topics[mask], dist_ready[mask]

