Skip to content

Commit 975fd5c

Browse files
authored
SchemaFromExistingGraphExtractor component (#355)
* Add SchemaFromExistingGraphExtractor component Parses the result from get_structured_schema and returns a GraphSchema object * Extract required properties from existing constraints * Use BaseSchemaBuilder * Test happy path * Fix mypy * Better handling of additional_* parameters by relying on the default rules implemented in the GraphSchema object * Remove unused INDEX name in example
1 parent 1401f45 commit 975fd5c

File tree

4 files changed

+342
-4
lines changed

4 files changed

+342
-4
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""This example demonstrates how to use the SchemaFromExistingGraphExtractor component
2+
to automatically extract a schema from an existing Neo4j database.
3+
"""
4+
5+
import asyncio
6+
from pprint import pprint
7+
8+
import neo4j
9+
10+
from neo4j_graphrag.experimental.components.schema import (
11+
SchemaFromExistingGraphExtractor,
12+
GraphSchema,
13+
)
14+
15+
16+
URI = "neo4j+s://demo.neo4jlabs.com"
17+
AUTH = ("recommendations", "recommendations")
18+
DATABASE = "recommendations"
19+
20+
21+
async def main() -> None:
    """Extract the schema of the configured database and pretty-print it.

    Opens a driver on ``URI`` with ``AUTH``, runs a
    ``SchemaFromExistingGraphExtractor`` against ``DATABASE`` and prints the
    dumped ``GraphSchema`` model. The driver is closed when the ``with`` block
    exits.
    """
    connection = neo4j.GraphDatabase.driver(URI, auth=AUTH)
    with connection as driver:
        # Only the driver is mandatory; every keyword argument is optional.
        schema_extractor = SchemaFromExistingGraphExtractor(
            driver,
            neo4j_database=DATABASE,
            additional_properties=True,
            additional_node_types=True,
            additional_relationship_types=True,
            additional_patterns=True,
        )
        graph_schema: GraphSchema = await schema_extractor.run()
        # To persist the result: graph_schema.store_as_json("my_schema.json")
        schema_as_dict = graph_schema.model_dump()
        pprint(schema_as_dict)
40+
41+
42+
if __name__ == "__main__":
43+
asyncio.run(main())

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 175 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from __future__ import annotations
1616

1717
import json
18+
19+
import neo4j
1820
import logging
1921
import warnings
2022
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence, Callable
@@ -44,6 +46,10 @@
4446
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
4547
from neo4j_graphrag.llm import LLMInterface
4648
from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
49+
from neo4j_graphrag.schema import get_structured_schema
50+
51+
52+
logger = logging.getLogger(__name__)
4753

4854

4955
class PropertyType(BaseModel):
@@ -306,7 +312,12 @@ def from_file(
306312
raise SchemaValidationError(str(e)) from e
307313

308314

309-
class SchemaBuilder(Component):
315+
class BaseSchemaBuilder(Component):
    """Common parent of the components whose ``run`` produces a GraphSchema.

    This class only fixes the return type of ``run``; subclasses provide the
    actual implementation.
    """

    async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
        """Build and return a GraphSchema.

        Raises:
            NotImplementedError: always, subclasses must override this method.
        """
        raise NotImplementedError
318+
319+
320+
class SchemaBuilder(BaseSchemaBuilder):
310321
"""
311322
A builder class for constructing GraphSchema objects from given entities,
312323
relations, and their interrelationships defined in a potential schema.
@@ -424,7 +435,7 @@ async def run(
424435
)
425436

426437

427-
class SchemaFromTextExtractor(Component):
438+
class SchemaFromTextExtractor(BaseSchemaBuilder):
428439
"""
429440
A component for constructing GraphSchema objects from the output of an LLM after
430441
automatic schema extraction from text.
@@ -621,3 +632,165 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
621632
"patterns": extracted_patterns,
622633
}
623634
)
635+
636+
637+
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
    """A class to build a GraphSchema object from an existing graph.

    Uses the get_structured_schema function to extract existing node labels,
    relationship types, properties and existence constraints.

    Every ``additional_*`` argument defaults to ``None``. A value of ``None``
    is not forwarded to ``GraphSchema``, so the default rules implemented by
    ``GraphSchema`` apply for that setting.

    Args:
        driver (neo4j.Driver): connection to the neo4j database.
        additional_properties (Optional[bool], default None): see GraphSchema.
            Only forwarded to node/relationship types when truthy.
        additional_node_types (Optional[bool], default None): see GraphSchema.
        additional_relationship_types (Optional[bool], default None): see GraphSchema.
        additional_patterns (Optional[bool], default None): see GraphSchema.
        neo4j_database (Optional[str]): name of the neo4j database to use.
    """

    def __init__(
        self,
        driver: neo4j.Driver,
        additional_properties: bool | None = None,
        additional_node_types: bool | None = None,
        additional_relationship_types: bool | None = None,
        additional_patterns: bool | None = None,
        neo4j_database: Optional[str] = None,
    ) -> None:
        self.driver = driver
        self.database = neo4j_database

        self.additional_properties = additional_properties
        self.additional_node_types = additional_node_types
        self.additional_relationship_types = additional_relationship_types
        self.additional_patterns = additional_patterns

    @staticmethod
    def _extract_required_properties(
        structured_schema: dict[str, Any],
    ) -> list[tuple[str, str]]:
        """Extract a list of (node label (or rel type), property name) for which
        an "EXISTENCE" or "KEY" constraint is defined in the DB.

        Args:

            structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.

        Returns:

            list of tuples of (node label (or rel type), property name). A
            constraint covering several properties (composite key) yields one
            tuple per property.

        """
        required_constraint_types = (
            "NODE_PROPERTY_EXISTENCE",
            "NODE_KEY",
            "RELATIONSHIP_PROPERTY_EXISTENCE",
            "RELATIONSHIP_KEY",
        )
        schema_metadata = structured_schema.get("metadata", {})
        existence_constraint: list[tuple[str, str]] = []
        for constraint in schema_metadata.get("constraint", []):
            if constraint.get("type") not in required_constraint_types:
                continue
            # Existence constraints target one label and one property, but KEY
            # constraints may be composite: every property of the key must
            # exist, so all of them are reported (not only the first one).
            # Missing or empty fields simply contribute nothing.
            labels = constraint.get("labelsOrTypes") or []
            properties = constraint.get("properties") or []
            existence_constraint.extend(
                (lab, prop) for lab in labels for prop in properties
            )
        return existence_constraint

    def _to_schema_entity_dict(
        self,
        key: str,
        property_dict: list[dict[str, Any]],
        existence_constraint: list[tuple[str, str]],
    ) -> dict[str, Any]:
        """Build the dict describing one node or relationship type.

        Args:
            key (str): node label or relationship type.
            property_dict (list[dict[str, Any]]): property descriptions, each
                one read through its "property" and "type" keys.
            existence_constraint (list[tuple[str, str]]): (label or type,
                property name) pairs that must be flagged as required.

        Returns:
            dict with the "label" and "properties" keys, plus
            "additional_properties" when that option is truthy.
        """
        # Hashed lookup instead of scanning the list once per property.
        required_pairs = set(existence_constraint)
        entity_dict: dict[str, Any] = {
            "label": key,
            "properties": [
                {
                    "name": p["property"],
                    "type": p["type"],
                    "required": (key, p["property"]) in required_pairs,
                }
                for p in property_dict
            ],
        }
        # NOTE(review): unlike the schema-level flags handled in run(), an
        # explicit False is not forwarded here; presumably so that the
        # entity-level default applies - confirm against GraphSchema.
        if self.additional_properties:
            entity_dict["additional_properties"] = self.additional_properties
        return entity_dict

    async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
        """Read the structured schema from the database and build a GraphSchema.

        Positional and keyword arguments are accepted but not used.

        Returns:
            GraphSchema: the validated schema made of the node types,
            relationship types and patterns found in the database.
        """
        structured_schema = get_structured_schema(self.driver, database=self.database)
        existence_constraint = self._extract_required_properties(structured_schema)

        # node label with properties
        node_props = structured_schema["node_props"]
        node_labels = set(node_props)
        node_types: list[dict[str, Any]] = [
            self._to_schema_entity_dict(key, properties, existence_constraint)
            for key, properties in node_props.items()
        ]

        # relationships with properties
        rel_props = structured_schema["rel_props"]
        rel_labels = set(rel_props)
        relationship_types: list[dict[str, Any]] = [
            self._to_schema_entity_dict(key, properties, existence_constraint)
            for key, properties in rel_props.items()
        ]

        patterns = [
            (s["start"], s["type"], s["end"])
            for s in structured_schema["relationships"]
        ]

        # deal with nodes and relationships without properties: they only
        # show up in the patterns, so they are added with a bare label
        for source, rel, target in patterns:
            for label in (source, target):
                if label in node_labels:
                    continue
                if self.additional_properties is False:
                    logger.warning(
                        f"SCHEMA: found node label {label} without property and additional_properties=False: this node label will always be pruned!"
                    )
                node_labels.add(label)
                node_types.append({"label": label})
            if rel not in rel_labels:
                rel_labels.add(rel)
                relationship_types.append({"label": rel})

        schema_dict: dict[str, Any] = {
            "node_types": node_types,
            "relationship_types": relationship_types,
            "patterns": patterns,
        }
        # Options left as None are omitted so that GraphSchema applies its own
        # default rules for them.
        if self.additional_node_types is not None:
            schema_dict["additional_node_types"] = self.additional_node_types
        if self.additional_relationship_types is not None:
            schema_dict["additional_relationship_types"] = (
                self.additional_relationship_types
            )
        if self.additional_patterns is not None:
            schema_dict["additional_patterns"] = self.additional_patterns
        return GraphSchema.model_validate(
            schema_dict,
        )

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
SchemaBuilder,
4545
GraphSchema,
4646
SchemaFromTextExtractor,
47+
BaseSchemaBuilder,
4748
)
4849
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
4950
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
@@ -178,7 +179,7 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]:
178179
def _get_chunk_embedder(self) -> TextChunkEmbedder:
179180
return TextChunkEmbedder(embedder=self.get_default_embedder())
180181

181-
def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]:
182+
def _get_schema(self) -> BaseSchemaBuilder:
182183
"""
183184
Get the appropriate schema component based on configuration.
184185
Return SchemaFromTextExtractor for automatic extraction or SchemaBuilder for manual schema.

0 commit comments

Comments
 (0)