|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
17 | 17 | import json
|
| 18 | + |
| 19 | +import neo4j |
18 | 20 | import logging
|
19 | 21 | import warnings
|
20 | 22 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence, Callable
|
|
44 | 46 | from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
|
45 | 47 | from neo4j_graphrag.llm import LLMInterface
|
46 | 48 | from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
|
| 49 | +from neo4j_graphrag.schema import get_structured_schema |
| 50 | + |
| 51 | + |
| 52 | +logger = logging.getLogger(__name__) |
47 | 53 |
|
48 | 54 |
|
49 | 55 | class PropertyType(BaseModel):
|
@@ -306,7 +312,12 @@ def from_file(
|
306 | 312 | raise SchemaValidationError(str(e)) from e
|
307 | 313 |
|
308 | 314 |
|
309 |
| -class SchemaBuilder(Component): |
| 315 | +class BaseSchemaBuilder(Component): |
| 316 | + async def run(self, *args: Any, **kwargs: Any) -> GraphSchema: |
| 317 | + raise NotImplementedError() |
| 318 | + |
| 319 | + |
| 320 | +class SchemaBuilder(BaseSchemaBuilder): |
310 | 321 | """
|
311 | 322 | A builder class for constructing GraphSchema objects from given entities,
|
312 | 323 | relations, and their interrelationships defined in a potential schema.
|
@@ -424,7 +435,7 @@ async def run(
|
424 | 435 | )
|
425 | 436 |
|
426 | 437 |
|
427 |
| -class SchemaFromTextExtractor(Component): |
| 438 | +class SchemaFromTextExtractor(BaseSchemaBuilder): |
428 | 439 | """
|
429 | 440 | A component for constructing GraphSchema objects from the output of an LLM after
|
430 | 441 | automatic schema extraction from text.
|
@@ -621,3 +632,165 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
|
621 | 632 | "patterns": extracted_patterns,
|
622 | 633 | }
|
623 | 634 | )
|
| 635 | + |
| 636 | + |
| 637 | +class SchemaFromExistingGraphExtractor(BaseSchemaBuilder): |
| 638 | + """A class to build a GraphSchema object from an existing graph. |
| 639 | +
|
| 640 | + Uses the get_structured_schema function to extract existing node labels, |
| 641 | + relationship types, properties and existence constraints. |
| 642 | +
|
| 643 | + By default, the built schema does not allow any additional item (property, |
| 644 | + node label, relationship type or pattern). |
| 645 | +
|
| 646 | + Args: |
| 647 | + driver (neo4j.Driver): connection to the neo4j database. |
| 648 | + additional_properties (bool, default False): see GraphSchema |
| 649 | + additional_node_types (bool, default False): see GraphSchema |
| 650 | + additional_relationship_types (bool, default False): see GraphSchema: |
| 651 | + additional_patterns (bool, default False): see GraphSchema: |
| 652 | + neo4j_database (Optional | str): name of the neo4j database to use |
| 653 | + """ |
| 654 | + |
| 655 | + def __init__( |
| 656 | + self, |
| 657 | + driver: neo4j.Driver, |
| 658 | + additional_properties: bool | None = None, |
| 659 | + additional_node_types: bool | None = None, |
| 660 | + additional_relationship_types: bool | None = None, |
| 661 | + additional_patterns: bool | None = None, |
| 662 | + neo4j_database: Optional[str] = None, |
| 663 | + ) -> None: |
| 664 | + self.driver = driver |
| 665 | + self.database = neo4j_database |
| 666 | + |
| 667 | + self.additional_properties = additional_properties |
| 668 | + self.additional_node_types = additional_node_types |
| 669 | + self.additional_relationship_types = additional_relationship_types |
| 670 | + self.additional_patterns = additional_patterns |
| 671 | + |
| 672 | + @staticmethod |
| 673 | + def _extract_required_properties( |
| 674 | + structured_schema: dict[str, Any], |
| 675 | + ) -> list[tuple[str, str]]: |
| 676 | + """Extract a list of (node label (or rel type), property name) for which |
| 677 | + an "EXISTENCE" or "KEY" constraint is defined in the DB. |
| 678 | +
|
| 679 | + Args: |
| 680 | +
|
| 681 | + structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function. |
| 682 | +
|
| 683 | + Returns: |
| 684 | +
|
| 685 | + list of tuples of (node label (or rel type), property name) |
| 686 | +
|
| 687 | + """ |
| 688 | + schema_metadata = structured_schema.get("metadata", {}) |
| 689 | + existence_constraint = [] # list of (node label, property name) |
| 690 | + for constraint in schema_metadata.get("constraint", []): |
| 691 | + if constraint["type"] in ( |
| 692 | + "NODE_PROPERTY_EXISTENCE", |
| 693 | + "NODE_KEY", |
| 694 | + "RELATIONSHIP_PROPERTY_EXISTENCE", |
| 695 | + "RELATIONSHIP_KEY", |
| 696 | + ): |
| 697 | + properties = constraint["properties"] |
| 698 | + labels = constraint["labelsOrTypes"] |
| 699 | + # note: existence constraint only apply to a single property |
| 700 | + # and a single label |
| 701 | + prop = properties[0] |
| 702 | + lab = labels[0] |
| 703 | + existence_constraint.append((lab, prop)) |
| 704 | + return existence_constraint |
| 705 | + |
| 706 | + def _to_schema_entity_dict( |
| 707 | + self, |
| 708 | + key: str, |
| 709 | + property_dict: list[dict[str, Any]], |
| 710 | + existence_constraint: list[tuple[str, str]], |
| 711 | + ) -> dict[str, Any]: |
| 712 | + entity_dict: dict[str, Any] = { |
| 713 | + "label": key, |
| 714 | + "properties": [ |
| 715 | + { |
| 716 | + "name": p["property"], |
| 717 | + "type": p["type"], |
| 718 | + "required": (key, p["property"]) in existence_constraint, |
| 719 | + } |
| 720 | + for p in property_dict |
| 721 | + ], |
| 722 | + } |
| 723 | + if self.additional_properties: |
| 724 | + entity_dict["additional_properties"] = self.additional_properties |
| 725 | + return entity_dict |
| 726 | + |
| 727 | + async def run(self, *args: Any, **kwargs: Any) -> GraphSchema: |
| 728 | + structured_schema = get_structured_schema(self.driver, database=self.database) |
| 729 | + existence_constraint = self._extract_required_properties(structured_schema) |
| 730 | + |
| 731 | + # node label with properties |
| 732 | + node_labels = set(structured_schema["node_props"].keys()) |
| 733 | + node_types = [ |
| 734 | + self._to_schema_entity_dict(key, properties, existence_constraint) |
| 735 | + for key, properties in structured_schema["node_props"].items() |
| 736 | + ] |
| 737 | + |
| 738 | + # relationships with properties |
| 739 | + rel_labels = set(structured_schema["rel_props"].keys()) |
| 740 | + relationship_types = [ |
| 741 | + self._to_schema_entity_dict(key, properties, existence_constraint) |
| 742 | + for key, properties in structured_schema["rel_props"].items() |
| 743 | + ] |
| 744 | + |
| 745 | + patterns = [ |
| 746 | + (s["start"], s["type"], s["end"]) |
| 747 | + for s in structured_schema["relationships"] |
| 748 | + ] |
| 749 | + |
| 750 | + # deal with nodes and relationships without properties |
| 751 | + for source, rel, target in patterns: |
| 752 | + if source not in node_labels: |
| 753 | + if self.additional_properties is False: |
| 754 | + logger.warning( |
| 755 | + f"SCHEMA: found node label {source} without property and additional_properties=False: this node label will always be pruned!" |
| 756 | + ) |
| 757 | + node_labels.add(source) |
| 758 | + node_types.append( |
| 759 | + { |
| 760 | + "label": source, |
| 761 | + } |
| 762 | + ) |
| 763 | + if target not in node_labels: |
| 764 | + if self.additional_properties is False: |
| 765 | + logger.warning( |
| 766 | + f"SCHEMA: found node label {target} without property and additional_properties=False: this node label will always be pruned!" |
| 767 | + ) |
| 768 | + node_labels.add(target) |
| 769 | + node_types.append( |
| 770 | + { |
| 771 | + "label": target, |
| 772 | + } |
| 773 | + ) |
| 774 | + if rel not in rel_labels: |
| 775 | + rel_labels.add(rel) |
| 776 | + relationship_types.append( |
| 777 | + { |
| 778 | + "label": rel, |
| 779 | + } |
| 780 | + ) |
| 781 | + schema_dict: dict[str, Any] = { |
| 782 | + "node_types": node_types, |
| 783 | + "relationship_types": relationship_types, |
| 784 | + "patterns": patterns, |
| 785 | + } |
| 786 | + if self.additional_node_types is not None: |
| 787 | + schema_dict["additional_node_types"] = self.additional_node_types |
| 788 | + if self.additional_relationship_types is not None: |
| 789 | + schema_dict["additional_relationship_types"] = ( |
| 790 | + self.additional_relationship_types |
| 791 | + ) |
| 792 | + if self.additional_patterns is not None: |
| 793 | + schema_dict["additional_patterns"] = self.additional_patterns |
| 794 | + return GraphSchema.model_validate( |
| 795 | + schema_dict, |
| 796 | + ) |
0 commit comments