Endpoint for training supervised graph sage #409

Draft
wants to merge 21 commits into
base: main
255 changes: 255 additions & 0 deletions examples/python-runtime-V1.ipynb
@@ -0,0 +1,255 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from graphdatascience import GraphDataScience\n",
"from graphdatascience import __version__\n",
"\n",
"# Make sure you have installed the custom GDS Client distributed with this notebook\n",
"assert __version__ == \"1.8a1.dev1\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# From the Aura Console, get the Connection URI to your Neo4j instance and paste here\n",
"URI = \"neo4j+s://<dbid>-mlruntimedev.databases.neo4j-dev.io\"\n",
"# And paste the database password here\n",
"PASSWORD = \"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The usual GDS client initialization\n",
"gds = GraphDataScience(URI, auth=(\"neo4j\", PASSWORD))\n",
"gds.set_database(\"neo4j\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# We will load the Cora dataset\n",
"# The progress bar is sometimes wonky; don't worry about it\n",
"try:\n",
"    gds.graph.load_cora()\n",
"except Exception:\n",
"    # Errors are ignored here; the next cell checks that the graph was imported\n",
"    pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The graph import is completed when this command returns a non-empty list\n",
"gds.graph.list()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# GNN training!\n",
"\n",
"And now for the exciting stuff!\n",
"In the next cell, you will start a GNN training job.\n",
"Under the hood, a PyTorch Geometric GraphSAGE model is trained.\n",
"The call is asynchronous, so it returns immediately (unless there's an unexpected error 😱).\n",
"The training itself does not complete instantly, so you will have to wait for it to finish.\n",
"\n",
"## Observing the training progress\n",
"\n",
"You can observe the training progress by watching the logs.\n",
"This is done in the subsequent cell.\n",
"Log watching does not stop automatically, so you will have to stop it manually.\n",
"Once you see the message 'Training Done', you can interrupt the cell and continue.\n",
"\n",
"## Graph and training parameters\n",
"\n",
"| Parameter | Default | Type | Description |\n",
"|-----------|---------|------|-------------|\n",
"| graph_name | - | str | The name of the graph to train on. |\n",
"| model_name | - | str | The name of the model. Must be unique per database and username combination. Models cannot be cleaned up at this time. |\n",
"| feature_properties | - | List[str] | The node properties to use as model features. |\n",
"| target_property | - | str | The node property that contains the target class values. |\n",
"| node_labels | None | List[str] | The node labels to use for training. By default, all labels are used. |\n",
"| relationship_types | None | List[str] | The relationship types to use for training. By default, all types are used. |\n",
"| target_node_label | None | str | Indicates the nodes used for training. Only nodes with this label need to have the `target_property` defined. Other nodes are used for context. By default, all nodes are considered. |\n",
"| graph_sage_config | None | dict | Configuration for the GraphSAGE training. See below. |\n",
"\n",
"\n",
"## GraphSAGE parameters\n",
"\n",
"We have exposed several parameters of the PyG GraphSAGE model.\n",
"\n",
"| Parameter | Default | Description |\n",
"|-----------|---------|-------------|\n",
"| layer_config | {} | Configuration of the GraphSAGE layers. It supports `aggr`, `normalize`, `root_weight`, `project`, `bias` from [SAGEConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html). Additionally, you can provide message passing configuration from [MessagePassing](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MessagePassing.html#torch_geometric.nn.conv.MessagePassing). |\n",
"| num_neighbors | [25, 10] | Sample sizes for each layer. The length of this list is the number of layers used. All numbers must be >0. |\n",
"| dropout | 0.5 | Probability of dropping out neurons during training. Must be between 0 and 1. |\n",
"| hidden_channels | 256 | The dimension of each hidden layer. Higher values make training more expensive but increase representational capacity. Must be >0. |\n",
"| learning_rate | 0.003 | The learning rate. Must be >0. |\n",
"\n",
"Feel free to try any of them with values in the valid ranges.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Let's train!\n",
"job_id = gds.gnn.nodeClassification.train(\n",
"    \"cora\", \"myModel\", [\"features\"], \"subject\", [\"CITES\"], target_node_label=\"Paper\", node_labels=[\"Paper\"]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# And let's follow the progress by watching the logs\n",
"gds.gnn.nodeClassification.watch_logs(job_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Once the training is completed, we can retrieve the training result (metrics)\n",
"train_result = gds.run_cypher(\"RETURN gds.remoteml.getTrainResult('myModel')\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# And display it\n",
"train_result"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# GNN prediction!\n",
"\n",
"Wow, that was cool.\n",
"But training a model is only half the picture.\n",
"We also have to use it for something.\n",
"In this case, we will use it to predict the subject of papers in the Cora dataset.\n",
"\n",
"Again, this call is asynchronous, so it will return immediately.\n",
"Observe the progress by watching the logs.\n",
"\n",
"Once the prediction is completed, the predicted classes are added to the graph in the GDS Graph Catalog, as usual.\n",
"We can then retrieve the predictions by streaming them from the graph.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Let's trigger prediction!\n",
"job_id = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# And let's follow progress by watching the logs\n",
"gds.gnn.nodeClassification.watch_logs(job_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Now that prediction is done, let's see the predictions\n",
"cora = gds.graph.get(\"cora\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Now for some standard GDS stuff; streaming properties from the graph\n",
"predictions = gds.graph.nodeProperties.stream(\n",
"    cora, node_properties=[\"features\", \"myPredictions\"], separate_property_columns=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# And displaying them\n",
"predictions"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# And that's it!\n",
"\n",
"Thank you very much for participating in the testing.\n",
"We hope you enjoyed it.\n",
"If this was your first run of the notebook, now is the time to experiment: change the graph, the training parameters, etc.\n",
"For example, try out a heterogeneous graph problem. Can performance be improved by changing some parameter? Can you run training jobs in parallel, on multiple databases?\n",
"When you feel you are done, please go back to the Google Document and fill in our feedback form.\n",
"\n",
"Thank you!"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
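Review note: the notebook's GraphSAGE parameter table states value constraints (`num_neighbors` all >0, `dropout` between 0 and 1, `hidden_channels` >0, `learning_rate` >0), but this diff does not show where they are enforced. Below is a minimal sketch of a client-side check, assuming the defaults from the table. The helper name `validate_graph_sage_config` is hypothetical and not part of the client, and treating an empty `num_neighbors` list and the dropout endpoints 0 and 1 the way it does is an assumption.

```python
from typing import Any, Dict, List


def validate_graph_sage_config(config: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in a GraphSAGE config (hypothetical helper)."""
    problems = []

    # Defaults mirror the notebook's parameter table.
    num_neighbors = config.get("num_neighbors", [25, 10])
    if not num_neighbors or any(n <= 0 for n in num_neighbors):
        problems.append("num_neighbors must be a non-empty list of numbers > 0")

    dropout = config.get("dropout", 0.5)
    if not 0 <= dropout <= 1:
        problems.append("dropout must be between 0 and 1")

    if config.get("hidden_channels", 256) <= 0:
        problems.append("hidden_channels must be > 0")

    if config.get("learning_rate", 0.003) <= 0:
        problems.append("learning_rate must be > 0")

    return problems


# Three sampled layers and a lower learning rate: passes the checks.
ok_config = {"num_neighbors": [15, 10, 5], "learning_rate": 0.001}
# Violates two of the documented constraints.
bad_config = {"dropout": 1.5, "hidden_channels": 0}
```

A config that passes could then be handed to the training call through its `graph_sage_config` argument.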
5 changes: 4 additions & 1 deletion graphdatascience/endpoints.py
@@ -1,5 +1,6 @@
 from .algo.single_mode_algo_endpoints import SingleModeAlgoEndpoints
 from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder
+from .gnn.gnn_endpoints import GnnEndpoints
 from .graph.graph_endpoints import (
     GraphAlphaEndpoints,
     GraphBetaEndpoints,
@@ -32,7 +33,9 @@
 """


-class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints):
+class DirectEndpoints(
+    DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints
+):
     def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion):
         super().__init__(query_runner, namespace, server_version)

Empty file.
18 changes: 18 additions & 0 deletions graphdatascience/gnn/gnn_endpoints.py
@@ -0,0 +1,18 @@
from ..caller_base import CallerBase
from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace
from .gnn_nc_runner import GNNNodeClassificationRunner


class GNNRunner(UncallableNamespace, IllegalAttrChecker):
    @property
    def nodeClassification(self) -> GNNNodeClassificationRunner:
        return GNNNodeClassificationRunner(
            self._query_runner, f"{self._namespace}.nodeClassification", self._server_version
        )


class GnnEndpoints(CallerBase):
    @property
    def gnn(self) -> GNNRunner:
        return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version)
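Review note: the two properties above build the call path `gds.gnn.nodeClassification` by appending one segment to the namespace at each attribute access. A self-contained sketch of that chaining pattern follows. The classes `Runner`, `NodeClassificationRunner`, `GnnRunner` and `Endpoints` are simplified, hypothetical stand-ins rather than the client's real base classes, and the root namespace `"gds"` is an assumption.

```python
class Runner:
    """Hypothetical stand-in for the client's caller base class."""

    def __init__(self, query_runner, namespace, server_version):
        self._query_runner = query_runner
        self._namespace = namespace
        self._server_version = server_version


class NodeClassificationRunner(Runner):
    pass


class GnnRunner(Runner):
    @property
    def nodeClassification(self):
        # Each access returns a new runner whose namespace has one more segment.
        return NodeClassificationRunner(
            self._query_runner, f"{self._namespace}.nodeClassification", self._server_version
        )


class Endpoints(Runner):
    @property
    def gnn(self):
        return GnnRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version)


# The root namespace "gds" is assumed for illustration.
root = Endpoints(query_runner=None, namespace="gds", server_version=None)
resolved = root.gnn.nodeClassification._namespace
```

Because every property constructs a fresh runner, no state is shared between calls; the only thing carried along is the query runner, the growing namespace string and the server version.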
106 changes: 106 additions & 0 deletions graphdatascience/gnn/gnn_nc_runner.py
@@ -0,0 +1,106 @@
import json
import time
from typing import Any, List

from ..error.illegal_attr_checker import IllegalAttrChecker
from ..error.uncallable_namespace import UncallableNamespace


class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
    def make_graph_sage_config(self, graph_sage_config):
        GRAPH_SAGE_DEFAULT_CONFIG = {
            "layer_config": {},
            "num_neighbors": [25, 10],
            "dropout": 0.5,
            "hidden_channels": 256,
            "learning_rate": 0.003,
        }
        final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG
        if graph_sage_config:
            bad_keys = []
            for key in graph_sage_config:
                if key not in GRAPH_SAGE_DEFAULT_CONFIG:
                    bad_keys.append(key)
            if len(bad_keys) > 0:
                raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.")

            final_sage_config.update(graph_sage_config)
        return final_sage_config

    def watch_logs(self, job_id: str, logging_interval: int = 5):
        print(f"Watching logs of job {job_id}.")
        print("This needs to be interrupted manually in order to continue (for example when training is done).")

        def get_logs(offset) -> "Series[Any]":  # noqa: F821
            return self._query_runner.run_query(
                "RETURN gds.remoteml.getLogs($job_id, $offset)", params={"job_id": job_id, "offset": offset}
            ).squeeze()

        received_logs = 0
        # training_done is never set to True, so this loop runs until the user interrupts it
        training_done = False
        while not training_done:
            time.sleep(logging_interval)
            for log in get_logs(offset=received_logs):
                print(log)
                received_logs += 1
        return job_id

    def train(
        self,
        graph_name: str,
        model_name: str,
        feature_properties: List[str],
        target_property: str,
        relationship_types: List[str],
        target_node_label: str = None,
        node_labels: List[str] = None,
        graph_sage_config=None,
    ) -> str:
        mlConfigMap = {
            "featureProperties": feature_properties,
            "targetProperty": target_property,
            "job_type": "train",
            "nodeProperties": feature_properties + [target_property],
            "relationshipTypes": relationship_types,
            "graph_sage_config": self.make_graph_sage_config(graph_sage_config),
        }

        if target_node_label:
            mlConfigMap["targetNodeLabel"] = target_node_label
        if node_labels:
            mlConfigMap["nodeLabels"] = node_labels

        mlTrainingConfig = json.dumps(mlConfigMap)

        # token and uri will be injected by arrow_query_runner
        job_id = self._query_runner.run_query(
            "CALL gds.upload.graph($config) YIELD jobId",
            params={
                "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
            },
        ).jobId[0]

        print(f"Started job with jobId={job_id}. Use `gds.gnn.nodeClassification.watch_logs` to track progress.")
        return job_id

    def predict(
        self,
        graph_name: str,
        model_name: str,
        mutateProperty: str,
        predictedProbabilityProperty: str = None,
    ) -> str:
        mlConfigMap = {"job_type": "predict", "mutateProperty": mutateProperty}
        if predictedProbabilityProperty:
            mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty

        mlTrainingConfig = json.dumps(mlConfigMap)
        job_id = self._query_runner.run_query(
            "CALL gds.upload.graph($config) YIELD jobId",
            params={
                "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name},
            },
        ).jobId[0]

        print(f"Started job with jobId={job_id}. Use `gds.gnn.nodeClassification.watch_logs` to track progress.")
        return job_id
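Review note: `watch_logs` polls with a growing offset and has to be interrupted by hand. The notebook mentions a 'Training Done' message, so one possible alternative is a watcher that returns once that message shows up. The sketch below is only an illustration under that assumption: `watch_until_done`, `fetch_logs`, `done_marker` and `max_polls` are hypothetical names, the exact log text is not confirmed by this diff, and prediction jobs may log a different final message.

```python
import time
from typing import Callable, List


def watch_until_done(
    fetch_logs: Callable[[int], List[str]],
    done_marker: str = "Training Done",
    interval: float = 5.0,
    max_polls: int = 1000,
) -> List[str]:
    """Poll fetch_logs(offset) until a line contains done_marker or max_polls is reached."""
    received: List[str] = []
    for _ in range(max_polls):
        # Ask only for lines we have not seen yet, as watch_logs does with its offset.
        for line in fetch_logs(len(received)):
            print(line)
            received.append(line)
            if done_marker in line:
                return received
        time.sleep(interval)
    return received


# Stand-in log source for demonstration; a real one would query the database.
fake_log = ["Epoch 1", "Epoch 2", "Training Done"]


def fake_fetch(offset: int) -> List[str]:
    # Return at most one new line per poll, like a slowly growing log.
    return fake_log[offset:offset + 1]


lines = watch_until_done(fake_fetch, interval=0.0)
```

The `max_polls` bound keeps the loop from running forever if the marker never appears, at the cost of possibly returning before a long job finishes.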