diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 171136df..a2bd5119 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -1847,11 +1847,12 @@ def get_representative_docs(self, topic: int = None) -> List[str]: else: return self.representative_docs_ - @staticmethod def get_topic_tree( + self, hier_topics: pd.DataFrame, max_distance: float = None, tight_layout: bool = False, + custom_labels: Union[bool, str] = False, ) -> str: """Extract the topic tree such that it can be printed. @@ -1862,6 +1863,11 @@ def get_topic_tree( based on the Distance column in `hier_topics`. tight_layout: Whether to use a tight layout (narrow width) for easier readability if you have hundreds of topics. + custom_labels: If bool, whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. + If `str`, it uses labels from other aspects, e.g., "Aspect1". + NOTE: Custom labels are only generated for the original + un-merged topics. Returns: A tree that has the following structure when printed: @@ -1897,9 +1903,40 @@ def get_topic_tree( max_original_topic = hier_topics.Parent_ID.astype(int).min() - 1 + # Prepare tree labels to print + child_left_ids = hier_topics.Child_Left_ID.astype(int) + child_right_ids = hier_topics.Child_Right_ID.astype(int) + + # Get the new parent labels generated from `hierarchical_topics` + new_left_labels = {int(row["Child_Left_ID"]): row["Child_Left_Name"] for idx, row in hier_topics.iterrows()} + new_right_labels = {int(row["Child_Right_ID"]): row["Child_Right_Name"] for idx, row in hier_topics.iterrows()} + + if custom_labels: + left_labels = {} + if isinstance(custom_labels, str): + for topic, kws_info in self.topic_aspects_[custom_labels].items(): + label = "_".join([kw[0] for kw in kws_info[:5]]) # displaying top 5 kws + left_labels[topic] = label + elif self.custom_labels_ is not None and custom_labels: + left_labels = {topic_id: label for topic_id, label in enumerate(self.custom_labels_, -self._outliers)} + + right_labels = left_labels.copy() + + # We want to preserve the original labels from `topic_aspects_` or `custom_labels_` + # while adding in those generated from `hierarchical_topics` + new_left_labels.update(left_labels) + new_right_labels.update(right_labels) + + child_left_names = [new_left_labels[topic] for topic in child_left_ids] + child_right_names = [new_right_labels[topic] for topic in child_right_ids] + + else: + child_left_names = hier_topics.Child_Left_Name + child_right_names = hier_topics.Child_Right_Name + # Extract mapping from ID to name - topic_to_name = dict(zip(hier_topics.Child_Left_ID, hier_topics.Child_Left_Name)) - topic_to_name.update(dict(zip(hier_topics.Child_Right_ID, hier_topics.Child_Right_Name))) + topic_to_name = dict(zip(child_left_ids.astype(str), child_left_names)) + topic_to_name.update(dict(zip(child_right_ids.astype(str), child_right_names))) topic_to_name = {topic: name[:100] for topic, name in topic_to_name.items()} # Create tree