From a2c950920137d8a52d8da5e41b95ee710cfff3ef Mon Sep 17 00:00:00 2001 From: Maksim Terpilovskii Date: Sat, 28 Sep 2024 14:15:06 +0200 Subject: [PATCH] minor code improvements --- src/tmplot/_distance.py | 99 +++++++++--------- src/tmplot/_stability.py | 50 ++++++---- src/tmplot/_vis.py | 210 ++++++++++++++++++++------------------- 3 files changed, 187 insertions(+), 172 deletions(-) diff --git a/src/tmplot/_distance.py b/src/tmplot/_distance.py index a2b55d8..71f31ab 100644 --- a/src/tmplot/_distance.py +++ b/src/tmplot/_distance.py @@ -1,13 +1,17 @@ -__all__ = [ - 'get_topics_dist', 'get_topics_scatter', 'get_top_topic_words'] -from typing import Union, List +__all__ = ["get_topics_dist", "get_topics_scatter", "get_top_topic_words"] +from typing import Optional, Union, List from itertools import combinations -from pandas import DataFrame +from pandas import DataFrame, Index import numpy as np from scipy.special import kl_div from scipy.spatial import distance from sklearn.manifold import ( - TSNE, Isomap, LocallyLinearEmbedding, MDS, SpectralEmbedding) + TSNE, + Isomap, + LocallyLinearEmbedding, + MDS, + SpectralEmbedding, +) from ._helpers import calc_topics_marg_probs @@ -28,15 +32,14 @@ def _dist_jsd(a1: np.ndarray, a2: np.ndarray): def _dist_jef(a1: np.ndarray, a2: np.ndarray): vals = (a1 - a2) * (np.log(a1) - np.log(a2)) - vals[(vals <= 0) | ~np.isfinite(vals)] = 0. + vals[(vals <= 0) | ~np.isfinite(vals)] = 0.0 return vals.sum() def _dist_hel(a1: np.ndarray, a2: np.ndarray): a1[(a1 <= 0) | ~np.isfinite(a1)] = 1e-64 a2[(a2 <= 0) | ~np.isfinite(a2)] = 1e-64 - hel_val = distance.euclidean( - np.sqrt(a1), np.sqrt(a2)) / np.sqrt(2) + hel_val = distance.euclidean(np.sqrt(a1), np.sqrt(a2)) / np.sqrt(2) return hel_val @@ -52,9 +55,9 @@ def _dist_tv(a1: np.ndarray, a2: np.ndarray): return dist -def _dist_jac(a1: np.ndarray, a2: np.ndarray, top_words=100): - a = np.argsort(a1)[:-top_words-1:-1] - b = np.argsort(a2)[:-top_words-1:-1] +def _dist_jac(a1: np.ndarray, a2: np.ndarray, top_words=100): + a = np.argsort(a1)[: -top_words - 1 : -1] + b = np.argsort(a2)[: -top_words - 1 : -1] j_num = np.intersect1d(a, b, assume_unique=False).size j_den = np.union1d(a, b).size jac_val = 1 - j_num / j_den @@ -62,9 +65,8 @@ def _dist_jac(a1: np.ndarray, a2: np.ndarray, top_words=100): def get_topics_dist( - phi: Union[np.ndarray, DataFrame], - method: str = "sklb", - **kwargs) -> np.ndarray: + phi: Union[np.ndarray, DataFrame], method: str = "sklb", **kwargs +) -> np.ndarray: """Finding closest topics in models. Parameters @@ -110,16 +112,18 @@ def get_topics_dist( for i, j in topics_pairs: _dist_func = dist_funcs.get(method, "sklb") topics_dists[((i, j), (j, i))] = _dist_func( - phi_copy[:, i], phi_copy[:, j], **kwargs) + phi_copy[:, i], phi_copy[:, j], **kwargs + ) return topics_dists def get_topics_scatter( - topic_dists: np.ndarray, - theta: np.ndarray, - method: str = 'tsne', - method_kws: dict = None) -> DataFrame: + topic_dists: np.ndarray, + theta: np.ndarray, + method: str = "tsne", + method_kws: Optional[dict] = None, +) -> DataFrame: """Calculate topics coordinates for a scatter plot. Parameters @@ -146,52 +150,52 @@ def get_topics_scatter( Topics scatter coordinates. """ if not method_kws: - method_kws = {'n_components': 2} + method_kws = {"n_components": 2} - if method == 'tsne': - method_kws.setdefault('init', 'pca') - method_kws.setdefault('learning_rate', 'auto') - method_kws.setdefault( - 'perplexity', min(50, max(topic_dists.shape[0] // 2, 1))) + if method == "tsne": + method_kws.setdefault("init", "pca") + method_kws.setdefault("learning_rate", "auto") + method_kws.setdefault("perplexity", min(50, max(topic_dists.shape[0] // 2, 1))) transformer = TSNE(**method_kws) - elif method == 'sem': - method_kws.setdefault('affinity', 'precomputed') + elif method == "sem": + method_kws.setdefault("affinity", "precomputed") transformer = SpectralEmbedding(**method_kws) - elif method == 'mds': - method_kws.setdefault('dissimilarity', 'precomputed') - method_kws.setdefault('normalized_stress', 'auto') + elif method == "mds": + method_kws.setdefault("dissimilarity", "precomputed") + method_kws.setdefault("normalized_stress", "auto") transformer = MDS(**method_kws) - elif method == 'lle': - method_kws['method'] = 'standard' + elif method == "lle": + method_kws["method"] = "standard" transformer = LocallyLinearEmbedding(**method_kws) - elif method == 'ltsa': - method_kws['method'] = 'ltsa' + elif method == "ltsa": + method_kws["method"] = "ltsa" transformer = LocallyLinearEmbedding(**method_kws) - elif method == 'isomap': + elif method == "isomap": transformer = Isomap(**method_kws) coords = transformer.fit_transform(topic_dists) - topics_xy = DataFrame(coords, columns=['x', 'y']) - topics_xy['topic'] = topics_xy.index.astype(int) - topics_xy['size'] = calc_topics_marg_probs(theta) - size_sum = topics_xy['size'].sum() + topics_xy = DataFrame(coords, columns=Index(["x", "y"])) + topics_xy["topic"] = topics_xy.index.astype(int) + topics_xy["size"] = calc_topics_marg_probs(theta) + size_sum = topics_xy["size"].sum() if size_sum > 0: - topics_xy['size'] *= (100 / topics_xy['size'].sum()) + topics_xy["size"] *= 100 / topics_xy["size"].sum() else: - topics_xy['size'] = np.nan + topics_xy["size"] = np.nan return topics_xy def get_top_topic_words( - phi: DataFrame, - words_num: int = 20, - topics_idx: Union[List[int], np.ndarray] = None) -> DataFrame: + phi: DataFrame, + words_num: int = 20, + topics_idx: Optional[Union[List[int], np.ndarray]] = None, +) -> DataFrame: """Select top topic words from a fitted model. Parameters @@ -209,9 +213,6 @@ def get_top_topic_words( DataFrame Words with highest probabilities in all (or selected) topics. """ - return phi.loc[:, topics_idx or phi.columns]\ - .apply( - lambda x: x - .sort_values(ascending=False) - .head(words_num).index, axis=0 + return phi.loc[:, topics_idx or phi.columns].apply( + lambda x: x.sort_values(ascending=False).head(words_num).index, axis=0 ) diff --git a/src/tmplot/_stability.py b/src/tmplot/_stability.py index 3d34082..e165e50 100644 --- a/src/tmplot/_stability.py +++ b/src/tmplot/_stability.py @@ -1,9 +1,17 @@ -__all__ = ['get_closest_topics', 'get_stable_topics'] +__all__ = ["get_closest_topics", "get_stable_topics"] from typing import List, Tuple, Any import numpy as np import tqdm -from ._distance import _dist_klb, _dist_sklb, _dist_jsd, _dist_jef, _dist_hel, \ - _dist_bhat, _dist_jac, _dist_tv +from ._distance import ( + _dist_klb, + _dist_sklb, + _dist_jsd, + _dist_jef, + _dist_hel, + _dist_bhat, + _dist_jac, + _dist_tv, +) from ._helpers import get_phi dist_funcs = { @@ -19,11 +27,12 @@ def get_closest_topics( - models: List[Any], - ref: int = 0, - method: str = "sklb", - top_words: int = 100, - verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]: + models: List[Any], + ref: int = 0, + method: str = "sklb", + top_words: int = 100, + verbose: bool = True, +) -> Tuple[np.ndarray, np.ndarray]: """Finding closest topics in models. Parameters @@ -93,7 +102,6 @@ def enum_func(x): # Iterating over all models for mid, model in enum_func(models): - # Current model is equal to reference model, skipping if mid == ref: continue @@ -105,7 +113,8 @@ def enum_func(x): for t_ref in range(topics_num): for t in range(topics_num): all_vs_all_dists[t_ref, t] = dist_func( - model_ref_phi.iloc[:, t_ref], get_phi(model).iloc[:, t]) + model_ref_phi.iloc[:, t_ref], get_phi(model).iloc[:, t] + ) # Creating two arrays for the closest topics ids and distance values if method == "jac": @@ -119,14 +128,15 @@ def enum_func(x): def get_stable_topics( - closest_topics: np.ndarray, - dist: np.ndarray, - norm: bool = True, - inverse: bool = True, - inverse_factor: float = 1.0, - ref: int = 0, - thres: float = 0.9, - thres_models: int = 2) -> Tuple[np.ndarray, np.ndarray]: + closest_topics: np.ndarray, + dist: np.ndarray, + norm: bool = True, + inverse: bool = True, + inverse_factor: float = 1.0, + ref: int = 0, + thres: float = 0.9, + thres_models: int = 2, +) -> Tuple[np.ndarray, np.ndarray]: """Finding stable topics in models. Parameters @@ -179,7 +189,5 @@ def get_stable_topics( dist_arr = np.asarray(dist) dist_ready = dist_arr / dist_arr.max() if norm else dist_arr.copy() dist_ready = inverse_factor - dist_ready if inverse else dist_ready - mask = ( - np.sum(np.delete(dist_ready, ref, axis=1) >= thres, axis=1) - >= thres_models) + mask = np.sum(np.delete(dist_ready, ref, axis=1) >= thres, axis=1) >= thres_models return closest_topics[mask], dist_ready[mask] diff --git a/src/tmplot/_vis.py b/src/tmplot/_vis.py index a17d3cd..d3d02bb 100644 --- a/src/tmplot/_vis.py +++ b/src/tmplot/_vis.py @@ -1,34 +1,46 @@ # TODO: heatmap of docs in topics # TODO: topic dynamics in time # TODO: word cloud -__all__ = [ - 'plot_scatter_topics', 'plot_terms', 'plot_docs'] -from typing import Union, Sequence +__all__ = ["plot_scatter_topics", "plot_terms", "plot_docs"] +from typing import Optional, Union, Sequence +from IPython.display import HTML from pandas import DataFrame, option_context from numpy import ndarray from altair import ( - AxisConfig, Chart, X, Y, Size, Color, value, Text, Scale, Legend) + AxisConfig, + Chart, + X, + Y, + LayerChart, + Size, + Color, + value, + Text, + Scale, + Legend, +) def plot_scatter_topics( - topics_coords: Union[ndarray, DataFrame], - x_col: str = "x", - y_col: str = "y", - topic: int = None, - size_col: str = None, - label_col: str = None, - color_col: str = None, - topic_col: str = None, - font_size: int = 13, - x_kws: dict = None, - y_kws: dict = None, - chart_kws: dict = None, - circle_kws: dict = None, - circle_enc_kws: dict = None, - text_kws: dict = None, - text_enc_kws: dict = None, - size_kws: dict = None, - color_kws: dict = None) -> Chart: + topics_coords: Union[ndarray, DataFrame], + x_col: str = "x", + y_col: str = "y", + topic: int = None, + size_col: str = None, + label_col: str = None, + color_col: str = None, + topic_col: str = None, + font_size: int = 13, + x_kws: dict = None, + y_kws: dict = None, + chart_kws: dict = None, + circle_kws: dict = None, + circle_enc_kws: dict = None, + text_kws: dict = None, + text_enc_kws: dict = None, + size_kws: dict = None, + color_kws: dict = None, +) -> LayerChart: """Topics scatter plot in 2D. Parameters @@ -83,18 +95,19 @@ def plot_scatter_topics( chart_kws = {} if not x_kws: - x_kws = {'shorthand': x_col, 'axis': None} + x_kws = {"shorthand": x_col, "axis": None} if not y_kws: - y_kws = {'shorthand': y_col, 'axis': None} + y_kws = {"shorthand": y_col, "axis": None} if not circle_kws: - circle_kws = {"opacity": 0.33, "stroke": 'black', "strokeWidth": 1} + circle_kws = {"opacity": 0.33, "stroke": "black", "strokeWidth": 1} if not size_kws: size_kws = { - 'title': 'Marginal topic distribution', - 'scale': Scale(range=[0, 3000])} + "title": "Marginal topic distribution", + "scale": Scale(range=[0, 3000]), + } if not circle_enc_kws: circle_enc_kws = { @@ -102,21 +115,24 @@ def plot_scatter_topics( "y": Y(**y_kws), "size": Size(size_col, **size_kws) if size_col and not topics_coords[size_col].isna().any() - else value(500) + else value(500), } if not text_kws: text_kws = {"align": "center", "baseline": "middle"} if not color_kws: - color_kws = {}\ - if topic is None\ - else {'condition': { - "test": f"datum['topic'] == {topic}", "value": "red"}} - - data = DataFrame(topics_coords, columns=[x_col, y_col])\ - if isinstance(topics_coords, ndarray)\ + color_kws = ( + {} + if topic is None + else {"condition": {"test": f"datum['topic'] == {topic}", "value": "red"}} + ) + + data = ( + DataFrame(topics_coords, columns=[x_col, y_col]) + if isinstance(topics_coords, ndarray) else topics_coords.copy() + ) if not topic_col: topic_col = "topic" @@ -127,7 +143,8 @@ def plot_scatter_topics( "x": X(**x_kws), "y": Y(**y_kws), "text": Text(topic_col), - "size": value(font_size)} + "size": value(font_size), + } # Tooltips initialization tooltips = [] @@ -137,57 +154,47 @@ def plot_scatter_topics( tooltips.append(size_col) if tooltips: - circle_enc_kws.update({'tooltip': tooltips}) - text_enc_kws.update({'tooltip': tooltips}) + circle_enc_kws.update({"tooltip": tooltips}) + text_enc_kws.update({"tooltip": tooltips}) if color_kws: - circle_enc_kws.update({'color': Color(**color_kws)}) + circle_enc_kws.update({"color": Color(**color_kws)}) base = Chart(data, **chart_kws) - rule = base\ - .mark_rule()\ - .encode( - y='average(y)', - color=value('gray'), - size=value(0.2)) - - rule2 = base\ - .mark_rule()\ - .encode( - x='average(x)', - color=value('gray'), - size=value(0.2)) - - points = base\ - .mark_circle(**circle_kws)\ - .encode(**circle_enc_kws) - - text = base\ - .mark_text(**text_kws)\ - .encode(**text_enc_kws) - - return (rule + rule2 + points + text)\ - .configure_axis(labelFontSize=font_size, titleFontSize=font_size, grid=False)\ - .configure(axis=AxisConfig(disable=True))\ - .configure_view(stroke='transparent', strokeWidth=0)\ + rule = base.mark_rule().encode(y="average(y)", color=value("gray"), size=value(0.2)) + + rule2 = base.mark_rule().encode( + x="average(x)", color=value("gray"), size=value(0.2) + ) + + points = base.mark_circle(**circle_kws).encode(**circle_enc_kws) + + text = base.mark_text(**text_kws).encode(**text_enc_kws) + + return ( + (rule + rule2 + points + text) + .configure_axis(labelFontSize=font_size, titleFontSize=font_size, grid=False) + .configure(axis=AxisConfig(disable=True)) + .configure_view(stroke="transparent", strokeWidth=0) .configure_legend( - orient='bottom', - labelFontSize=font_size, - titleFontSize=font_size) + orient="bottom", labelFontSize=font_size, titleFontSize=font_size + ) + ) def plot_terms( - terms_probs: DataFrame, - x_col: str = 'Probability', - y_col: str = 'Terms', - color_col: str = 'Type', - font_size: int = 13, - chart_kws: dict = None, - bar_kws: dict = None, - x_kws: dict = None, - y_kws: dict = None, - color_kws: dict = None) -> Chart: + terms_probs: DataFrame, + x_col: str = "Probability", + y_col: str = "Terms", + color_col: str = "Type", + font_size: int = 13, + chart_kws: Optional[dict] = None, + bar_kws: Optional[dict] = None, + x_kws: Optional[dict] = None, + y_kws: Optional[dict] = None, + color_kws: Optional[dict] = None, +) -> Chart: """Plot words conditional and marginal probabilities. Parameters @@ -219,37 +226,36 @@ def plot_terms( Terms probabilities chart. """ if not x_kws: - x_kws = {'stack': None} + x_kws = {"stack": None} if not y_kws: - y_kws = {'sort': None, 'title': None} + y_kws = {"sort": None, "title": None} if not color_kws: color_kws = { - 'shorthand': color_col, - 'legend': Legend(orient='bottom'), - 'scale': Scale(scheme='category20') + "shorthand": color_col, + "legend": Legend(orient="bottom"), + "scale": Scale(scheme="category20"), } if not chart_kws: chart_kws = {} if not bar_kws: bar_kws = {} - return Chart(data=terms_probs, **chart_kws)\ - .mark_bar(**bar_kws)\ - .encode( - x=X(x_col, **x_kws), - y=Y(y_col, **y_kws), - color=Color(**color_kws) - )\ - .configure_axis(labelFontSize=font_size, titleFontSize=font_size)\ + return ( + Chart(data=terms_probs, **chart_kws) + .mark_bar(**bar_kws) + .encode(x=X(x_col, **x_kws), y=Y(y_col, **y_kws), color=Color(**color_kws)) + .configure_axis(labelFontSize=font_size, titleFontSize=font_size) .configure_legend( - labelFontSize=font_size, titleFontSize=font_size, - columns=1, labelLimit=250) + labelFontSize=font_size, titleFontSize=font_size, columns=1, labelLimit=250 + ) + ) def plot_docs( - docs: Union[Sequence[str], DataFrame], - styles: str = None, - html_kws: dict = None) -> DataFrame: + docs: Union[Sequence[str], DataFrame], + styles: Optional[str] = None, + html_kws: Optional[dict] = None, +) -> HTML: """Documents plotting functionality for report interface. Parameters @@ -267,11 +273,11 @@ def plot_docs( ipywidgets.HTML Topic documents. """ - from IPython.display import HTML - if styles is None: - styles = '' + styles = ( + "" + ) if html_kws is None: # html_kws = {'classes': 'plot'} html_kws = {} @@ -279,8 +285,8 @@ def plot_docs( if isinstance(docs, DataFrame): df_docs = docs.copy() else: - df_docs = DataFrame({'docs': docs}) + df_docs = DataFrame({"docs": docs}) - with option_context('display.max_colwidth', 0): + with option_context("display.max_colwidth", 0): # df_docs.style.set_properties(**{'text-align': 'center'}) return HTML(styles + df_docs.to_html(**html_kws))