Skip to content

Commit

Permalink
feat: improvements in default test generation (explodinggradients#1661)
Browse files Browse the repository at this point in the history
- [x] Make sure test generation runs with short docs, long docs, with
small number of docs etc
- [x] Tune default settings for the above
- [x] Relaxed filters for query creation
  • Loading branch information
shahules786 committed Nov 14, 2024
1 parent 162370f commit d162f44
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 101 deletions.
9 changes: 7 additions & 2 deletions src/ragas/testset/persona.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import random
import typing as t

import numpy as np
Expand All @@ -19,7 +18,7 @@ def default_filter(node: Node) -> bool:
node.type.name == "DOCUMENT"
and node.properties.get("summary_embedding") is not None
):
return random.random() < 0.25
return True
else:
return False

Expand Down Expand Up @@ -92,8 +91,14 @@ def generate_personas_from_kg(
"""

nodes = [node for node in kg.nodes if filter_fn(node)]
if len(nodes) == 0:
raise ValueError(
"No nodes that satisfied the given filer. Try changing the filter."
)

summaries = [node.properties.get("summary") for node in nodes]
summaries = [summary for summary in summaries if isinstance(summary, str)]
num_personas = min(num_personas, len(summaries))

embeddings = []
for node in nodes:
Expand Down
23 changes: 18 additions & 5 deletions src/ragas/testset/synthesizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing as t

from ragas.llms import BaseRagasLLM
from ragas.testset.graph import KnowledgeGraph
from ragas.testset.synthesizers.multi_hop import (
MultiHopAbstractQuerySynthesizer,
MultiHopSpecificQuerySynthesizer,
Expand All @@ -14,12 +15,24 @@
QueryDistribution = t.List[t.Tuple[BaseSynthesizer, float]]


def default_query_distribution(llm: BaseRagasLLM) -> QueryDistribution:
return [
(SingleHopSpecificQuerySynthesizer(llm=llm), 0.5),
(MultiHopAbstractQuerySynthesizer(llm=llm), 0.25),
(MultiHopSpecificQuerySynthesizer(llm=llm), 0.25),
def default_query_distribution(
llm: BaseRagasLLM, kg: t.Optional[KnowledgeGraph] = None
) -> QueryDistribution:
""" """
default_queries = [
SingleHopSpecificQuerySynthesizer(llm=llm),
MultiHopAbstractQuerySynthesizer(llm=llm),
MultiHopSpecificQuerySynthesizer(llm=llm),
]
if kg is not None:
available_queries = []
for query in default_queries:
if query.get_node_clusters(kg):
available_queries.append(query)
else:
available_queries = default_queries

return [(query, 1 / len(available_queries)) for query in available_queries]


__all__ = [
Expand Down
11 changes: 6 additions & 5 deletions src/ragas/testset/synthesizers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
from ragas._analytics import TestsetGenerationEvent, track
from ragas.callbacks import new_group
from ragas.cost import TokenUsageParser
from ragas.embeddings.base import (
BaseRagasEmbeddings,
LlamaIndexEmbeddingsWrapper,
)
from ragas.embeddings.base import BaseRagasEmbeddings, LlamaIndexEmbeddingsWrapper
from ragas.executor import Executor
from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LlamaIndexLLMWrapper
from ragas.run_config import RunConfig
Expand Down Expand Up @@ -155,6 +152,7 @@ def generate_with_langchain_docs(

if not transforms:
transforms = default_transforms(
documents=list(documents),
llm=transforms_llm or self.llm,
embedding_model=transforms_embedding_model,
)
Expand Down Expand Up @@ -224,6 +222,7 @@ def generate_with_llamaindex_docs(
transforms_embedding_model
)
transforms = default_transforms(
documents=[LCDocument(page_content=doc.text) for doc in documents],
llm=llm_for_transforms,
embedding_model=embedding_model_for_transforms,
)
Expand Down Expand Up @@ -312,7 +311,9 @@ def generate(
if run_config is not None:
self.llm.set_run_config(run_config)

query_distribution = query_distribution or default_query_distribution(self.llm)
query_distribution = query_distribution or default_query_distribution(
self.llm, self.knowledge_graph
)
callbacks = callbacks or []

# dict to store any callbacks we define
Expand Down
23 changes: 14 additions & 9 deletions src/ragas/testset/synthesizers/multi_hop/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from ragas.prompt import PydanticPrompt
from ragas.testset.graph import KnowledgeGraph
from ragas.testset.graph import KnowledgeGraph, Node
from ragas.testset.graph_queries import get_child_nodes
from ragas.testset.persona import Persona, PersonaList
from ragas.testset.synthesizers.multi_hop.base import (
Expand Down Expand Up @@ -42,6 +42,17 @@ class MultiHopAbstractQuerySynthesizer(MultiHopQuerySynthesizer):
concept_combination_prompt: PydanticPrompt = ConceptCombinationPrompt()
theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt()

def get_node_clusters(self, knowledge_graph: KnowledgeGraph) -> t.List[t.Set[Node]]:

node_clusters = knowledge_graph.find_indirect_clusters(
relationship_condition=lambda rel: (
True if rel.get_property("summary_similarity") else False
),
depth_limit=3,
)
logger.info("found %d clusters", len(node_clusters))
return node_clusters

async def _generate_scenarios(
self,
n: int,
Expand All @@ -61,18 +72,12 @@ async def _generate_scenarios(
4. Sample diverse combinations of scenarios to get n samples
"""

node_clusters = knowledge_graph.find_indirect_clusters(
relationship_condition=lambda rel: (
True if rel.get_property("summary_similarity") else False
),
depth_limit=3,
)
logger.info("found %d clusters", len(node_clusters))
node_clusters = self.get_node_clusters(knowledge_graph)
scenarios = []

if len(node_clusters) == 0:
raise ValueError(
"No clusters found in the knowledge graph. Use a different Synthesizer."
"No clusters found in the knowledge graph. Try changing the relationship condition."
)
num_sample_per_cluster = int(np.ceil(n / len(node_clusters)))

Expand Down
2 changes: 1 addition & 1 deletion src/ragas/testset/synthesizers/multi_hop/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def prepare_combinations(
valid_nodes = []
for node in nodes:
node_themes = [
theme.lower() for theme in node.get_property(property_name)
theme.lower() for theme in node.properties.get(property_name, [])
]
if node.get_property(property_name) and any(
concept.lower() in node_themes for concept in combination
Expand Down
42 changes: 27 additions & 15 deletions src/ragas/testset/synthesizers/multi_hop/specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from ragas.prompt import PydanticPrompt
from ragas.testset.graph import KnowledgeGraph
from ragas.testset.graph import KnowledgeGraph, Node
from ragas.testset.persona import Persona, PersonaList
from ragas.testset.synthesizers.multi_hop.base import (
MultiHopQuerySynthesizer,
Expand Down Expand Up @@ -38,9 +38,26 @@ class MultiHopSpecificQuerySynthesizer(MultiHopQuerySynthesizer):
"""

name: str = "multi_hop_specific_query_synthesizer"
relation_type: str = "entities_overlap"
property_name: str = "entities"
theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt()
generate_query_reference_prompt: PydanticPrompt = QueryAnswerGenerationPrompt()

def get_node_clusters(self, knowledge_graph: KnowledgeGraph) -> t.List[t.Set[Node]]:

cluster_dict = knowledge_graph.find_direct_clusters(
relationship_condition=lambda rel: (
True if rel.type == self.relation_type else False
)
)
logger.info("found %d clusters", len(cluster_dict))
node_clusters = []
for key_node, list_of_nodes in cluster_dict.items():
for node in list_of_nodes:
node_clusters.append((key_node, node))

return node_clusters

async def _generate_scenarios(
self,
n: int,
Expand All @@ -61,26 +78,21 @@ async def _generate_scenarios(
4. Return the list of scenarios of length n
"""

cluster_dict = knowledge_graph.find_direct_clusters(
relationship_condition=lambda rel: (
True if rel.type == "entities_overlap" else False
node_clusters = self.get_node_clusters(knowledge_graph)

if len(node_clusters) == 0:
raise ValueError(
"No clusters found in the knowledge graph. Try changing the relationship condition."
)
)

num_sample_per_cluster = int(np.ceil(n / len(node_clusters)))

valid_relationships = [
rel
for rel in knowledge_graph.relationships
if rel.type == "entities_overlap"
if rel.type == self.relation_type
]

node_clusters = []
for key_node, list_of_nodes in cluster_dict.items():
for node in list_of_nodes:
node_clusters.append((key_node, node))

logger.info("found %d clusters", len(cluster_dict))
scenarios = []
num_sample_per_cluster = int(np.ceil(n / len(node_clusters)))

for cluster in node_clusters:
if len(scenarios) < n:
Expand All @@ -106,7 +118,7 @@ async def _generate_scenarios(
overlapped_items,
PersonaList(personas=persona_list),
persona_concepts,
property_name="entities",
property_name=self.property_name,
)
base_scenarios = self.sample_diverse_combinations(
base_scenarios, num_sample_per_cluster
Expand Down
46 changes: 35 additions & 11 deletions src/ragas/testset/synthesizers/single_hop/specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import logging
import typing as t
from collections import defaultdict
from dataclasses import dataclass

import numpy as np

from ragas.prompt import PydanticPrompt
from ragas.testset.graph import KnowledgeGraph
from ragas.testset.graph import KnowledgeGraph, Node
from ragas.testset.persona import Persona, PersonaList
from ragas.testset.synthesizers.base import BaseScenario
from ragas.testset.synthesizers.prompts import (
Expand Down Expand Up @@ -40,6 +41,37 @@ class SingleHopScenario(BaseScenario):
class SingleHopSpecificQuerySynthesizer(SingleHopQuerySynthesizer):
name: str = "single_hop_specifc_query_synthesizer"
theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt()
property_name: str = "entities"

def get_node_clusters(self, knowledge_graph: KnowledgeGraph) -> t.List[Node]:

node_type_dict = defaultdict(int)
for node in knowledge_graph.nodes:
if (
node.type.name == "CHUNK"
and node.get_property(self.property_name) is not None
):
node_type_dict["CHUNK"] += 1
elif (
node.type.name == "DOCUMENT"
and node.get_property(self.property_name) is not None
):
node_type_dict["DOCUMENT"] += 1
else:
pass

node_filter = (
"CHUNK"
if node_type_dict["CHUNK"] > node_type_dict["DOCUMENT"]
else "DOCUMENT"
)

nodes = []
for node in knowledge_graph.nodes:
if node.type.name == node_filter:
nodes.append(node)

return nodes

async def _generate_scenarios(
self,
Expand All @@ -61,15 +93,7 @@ async def _generate_scenarios(
4. Return the list of scenarios
"""

property_name = "entities"
nodes = []
for node in knowledge_graph.nodes:
if (
node.type.name == "CHUNK"
and node.get_property(property_name) is not None
):
nodes.append(node)

nodes = self.get_node_clusters(knowledge_graph)
if len(nodes) == 0:
raise ValueError("No nodes found with the `entities` property.")
samples_per_node = int(np.ceil(n / len(nodes)))
Expand All @@ -78,7 +102,7 @@ async def _generate_scenarios(
for node in nodes:
if len(scenarios) >= n:
break
themes = node.get_property(property_name)
themes = node.properties.get(self.property_name, [""])
prompt_input = ThemesPersonasInput(themes=themes, personas=persona_list)
persona_concepts = await self.theme_persona_matching_prompt.generate(
data=prompt_input, llm=self.llm, callbacks=callbacks
Expand Down
Loading

0 comments on commit d162f44

Please sign in to comment.