diff --git a/examples/python-runtime-V1.ipynb b/examples/python-runtime-V1.ipynb new file mode 100644 index 000000000..5912ad396 --- /dev/null +++ b/examples/python-runtime-V1.ipynb @@ -0,0 +1,255 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from graphdatascience import GraphDataScience\n", + "from graphdatascience import __version__\n", + "\n", + "# Make sure you have installed the custom GDS Client distributed with this notebook\n", + "assert __version__ == \"1.8a1.dev1\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# From the Aura Console, get the Connection URI to your Neo4j instance and paste here\n", + "URI = \"neo4j+s://-mlruntimedev.databases.neo4j-dev.io\"\n", + "# And paste the database password here\n", + "PASSWORD = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The usual GDS client initialization\n", + "gds = GraphDataScience(URI, auth=(\"neo4j\", PASSWORD))\n", + "gds.set_database(\"neo4j\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We will load the Cora dataset\n", + "# The progress bar is sometimes wonky; don't worry about it\n", + "try:\n", + " gds.graph.load_cora()\n", + "except:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The graph import is completed when this command returns a non-empty list\n", + "gds.graph.list()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# GNN training!\n", + "\n", + "And now for the exciting stuff!\n", + "In the next cell, you will start a GNN training job.\n", + "In actuality, it is a PyTorch-Geometric GraphSAGE model being trained.\n", + "It happens asynchronously, so it will return 
immediately (unless there's an unexpected error 😱).\n", + "Of course, the training does not complete instantly, so you will have to wait for it to finish.\n", + "\n", + "## Observing the training progress\n", + "\n", + "You can observe the training progress by watching the logs.\n", + "This is done in the subsequent cell.\n", + "The watching doesn't automatically stop, so you will have to stop it manually.\n", + "Once you see the message 'Training Done', you can interrupt the cell and continue.\n", + "\n", + "## Graph and training parameters\n", + "\n", + "\n", + "\n", + "\n", + "| Parameter | Default | Type | Description |\n", + "|--------------------|----------------|----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| graph_name | - | str | The name of the graph to train on. |\n", + "| model_name | - | str | The name of the model. Must be unique per database and username combination. Models cannot be cleaned up at this time. |\n", + "| feature_properties | - | List[str] | The node properties to use as model features. |\n", + "| target_property | - | str | The node property that contains the target class values. |\n", + "| node_labels | None | List[str] | The node labels to use for training. By default, all labels are used. |\n", + "| relationship_types | None | List[str] | The relationship types to use for training. By default, all types are used. |\n", + "| target_node_label | None | str | Indicates the nodes used for training. Only nodes with this label need to have the `target_property` defined. Other nodes are used for context. By default, all nodes are considered. |\n", + "| graph_sage_config | None | dict | Configuration for the GraphSAGE training. See below. 
|\n", + "\n", + "\n", + "## GraphSAGE parameters\n", + "\n", + "We have exposed several parameters of the PyG GraphSAGE model.\n", + "\n", + "| Parameter | Default | Description |\n", + "|-----------------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| layer_config | {} | Configuration of the GraphSAGE layers. It supports `aggr`, `normalize`, `root_weight`, `project`, `bias` from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html). Additionally, you can provide message passing configuration from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MessagePassing.html#torch_geometric.nn.conv.MessagePassing). |\n", + "| num_neighbors | [25, 10] | Sample sizes for each layer. The length of this list is the number of layers used. All numbers must be >0. |\n", + "| dropout | 0.5 | Probability of dropping out neurons during training. Must be between 0 and 1. |\n", + "| hidden_channels | 256 | The dimension of each hidden layer. Higher value means more expensive training, but higher level of representation. Must be >0. |\n", + "| learning_rate | 0.003 | The learning rate. Must be >0. 
|\n", + "\n", + "Please try out any of them, with any values you find useful.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's train!\n", + "job_id = gds.gnn.nodeClassification.train(\n", + " \"cora\", \"myModel\", [\"features\"], \"subject\", [\"CITES\"], target_node_label=\"Paper\", node_labels=[\"Paper\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# And let's follow the progress by watching the logs\n", + "gds.gnn.nodeClassification.watch_logs(job_id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Once the training is completed, we can retrieve the training result (metrics)\n", + "train_result = gds.run_cypher(\"RETURN gds.remoteml.getTrainResult('myModel')\");" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# And display it\n", + "train_result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# GNN prediction!\n", + "\n", + "Wow, that was cool.\n", + "But training a model is only half the picture.\n", + "We also have to use it for something.\n", + "In this case, we will use it to predict the subject of papers in the Cora dataset.\n", + "\n", + "Again, this call is asynchronous, so it will return immediately.\n", + "Observe the progress by watching the logs.\n", + "\n", + "Once the prediction is completed, the predicted classes are added to the graph in the GDS Graph Catalog (as usual).\n", + "We can retrieve the prediction result (the predictions themselves) by streaming from the graph.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's trigger prediction!\n", + "job_id = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# And let's follow progress by watching the logs\n", + "gds.gnn.nodeClassification.watch_logs(job_id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Now that prediction is done, let's see the predictions\n", + "cora = gds.graph.get(\"cora\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Now for some standard GDS stuff; streaming properties from the graph\n", + "predictions = gds.graph.nodeProperties.stream(\n", + " cora, node_properties=[\"features\", \"myPredictions\"], separate_property_columns=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# And displaying them\n", + "predictions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# And that's it!\n", + "\n", + "Thank you very much for participating in the testing.\n", + "We hope you enjoyed it.\n", + "If you've run the notebook for the first time, now's the time to experiment by changing the graph, the training parameters, etc.\n", + "For example, why not try out a heterogeneous graph problem? Or see whether performance can be improved by changing some parameter? Or run training jobs in parallel, on multiple databases?\n", + "If you're feeling like you're done, please go back to the Google Document and fill in our feedback form.\n", + "\n", + "Thank you!" 
class GNNRunner(UncallableNamespace, IllegalAttrChecker):
    """Namespace object whose only member is the ``nodeClassification`` sub-namespace."""

    @property
    def nodeClassification(self) -> GNNNodeClassificationRunner:
        """Return a node classification runner scoped to ``<namespace>.nodeClassification``.

        A new runner is built on every access; it shares this object's query
        runner and server version.
        """
        nc_namespace = f"{self._namespace}.nodeClassification"
        return GNNNodeClassificationRunner(self._query_runner, nc_namespace, self._server_version)
class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
    """Client-side entry points for GNN node classification jobs.

    ``train`` and ``predict`` each submit a job through ``gds.upload.graph`` and
    return its job id; ``watch_logs`` prints the logs of such a job by polling
    ``gds.remoteml.getLogs``.
    """

    def make_graph_sage_config(self, graph_sage_config):
        """Merge user-supplied GraphSAGE settings over the defaults.

        Args:
            graph_sage_config: Mapping of setting name to value, or a falsy
                value (``None`` / ``{}``) to get the defaults only.

        Returns:
            A fresh dict containing every default key, with the supplied
            overrides applied on top.

        Raises:
            ValueError: If ``graph_sage_config`` contains a key that is not one
                of the known settings.
        """
        # Built per call, so the returned dict never shares state between calls.
        final_sage_config = {
            "layer_config": {},
            "num_neighbors": [25, 10],
            "dropout": 0.5,
            "hidden_channels": 256,
            "learning_rate": 0.003,
        }
        if graph_sage_config:
            bad_keys = [str(key) for key in graph_sage_config if key not in final_sage_config]
            if bad_keys:
                raise ValueError(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.")
            final_sage_config.update(graph_sage_config)
        return final_sage_config

    def watch_logs(self, job_id: str, logging_interval: int = 5) -> str:
        """Print the logs of a job, polling until the watch is interrupted.

        The loop has no stop condition of its own: it sleeps, fetches the log
        entries after the ones already printed, prints them, and repeats. It
        ends when the caller interrupts it (``KeyboardInterrupt``, e.g. by
        interrupting the notebook cell), after which the job id is returned.

        Args:
            job_id: Id of the job, as returned by ``train`` or ``predict``.
            logging_interval: Seconds to sleep before each poll.

        Returns:
            The ``job_id`` that was passed in.
        """
        print(f"Watching logs of job {job_id}.")
        print("This needs to be interrupted manually in order to continue (for example when training is done).")

        def get_logs(offset: int) -> "Series[Any]":  # noqa: F821
            return self._query_runner.run_query(
                "RETURN gds.remoteml.getLogs($job_id, $offset)", params={"job_id": job_id, "offset": offset}
            ).squeeze()

        # Count of entries printed so far; sent as the offset so that each poll
        # only asks for entries that have not been shown yet.
        received_logs = 0
        try:
            while True:
                time.sleep(logging_interval)
                for log in get_logs(offset=received_logs):
                    print(log)
                    received_logs += 1
        except KeyboardInterrupt:
            # Interrupting is the intended way to stop watching; swallow it so
            # the caller gets the job id back instead of a traceback.
            pass
        return job_id

    def train(
        self,
        graph_name: str,
        model_name: str,
        feature_properties: List[str],
        target_property: str,
        relationship_types: List[str],
        target_node_label: str = None,
        node_labels: List[str] = None,
        graph_sage_config=None,
    ) -> str:
        """Start a node classification training job and return its job id.

        The call only submits the job; use ``watch_logs`` to follow it.

        Args:
            graph_name: Name of the graph to train on.
            model_name: Name under which the model is stored.
            feature_properties: Node properties used as model features.
            target_property: Node property holding the target class values.
            relationship_types: Relationship types to use for training.
            target_node_label: Label of the nodes used for training; left out
                of the job configuration when not given.
            node_labels: Node labels to use for training; left out of the job
                configuration when not given.
            graph_sage_config: GraphSAGE overrides, see
                ``make_graph_sage_config``.

        Returns:
            The id of the submitted job.

        Raises:
            ValueError: If ``graph_sage_config`` contains an unknown key.
        """
        ml_config = {
            "featureProperties": feature_properties,
            "targetProperty": target_property,
            "job_type": "train",
            # Unpacking accepts any iterable of names, not only lists.
            "nodeProperties": [*feature_properties, target_property],
            "relationshipTypes": relationship_types,
            "graph_sage_config": self.make_graph_sage_config(graph_sage_config),
        }

        if target_node_label:
            ml_config["targetNodeLabel"] = target_node_label
        if node_labels:
            ml_config["nodeLabels"] = node_labels

        return self._submit_job(graph_name, model_name, ml_config)

    def predict(
        self,
        graph_name: str,
        model_name: str,
        mutateProperty: str,
        predictedProbabilityProperty: str = None,
    ) -> str:
        """Start a node classification prediction job and return its job id.

        The call only submits the job; use ``watch_logs`` to follow it.

        Args:
            graph_name: Name of the graph to predict on.
            model_name: Name of the model to predict with.
            mutateProperty: Sent to the job as ``mutateProperty``.
            predictedProbabilityProperty: Sent to the job as
                ``predictedProbabilityProperty``; left out when not given.

        Returns:
            The id of the submitted job.
        """
        ml_config = {"job_type": "predict", "mutateProperty": mutateProperty}
        if predictedProbabilityProperty:
            ml_config["predictedProbabilityProperty"] = predictedProbabilityProperty

        return self._submit_job(graph_name, model_name, ml_config)

    def _submit_job(self, graph_name: str, model_name: str, ml_config: dict) -> str:
        """Submit ``ml_config`` as a job via ``gds.upload.graph`` and return the job id."""
        # The job configuration travels as a JSON string under "mlTrainingConfig",
        # for prediction jobs as well.
        # token and uri will be injected by arrow_query_runner
        job_id = self._query_runner.run_query(
            "CALL gds.upload.graph($config) YIELD jobId",
            params={
                "config": {
                    "mlTrainingConfig": json.dumps(ml_config),
                    "graphName": graph_name,
                    "modelName": model_name,
                },
            },
        ).jobId[0]

        print(f"Started job with jobId={job_id}. Use `gds.gnn.nodeClassification.watch_logs` to track progress.")
        return job_id
def received_headers(self, headers: Dict[str, Any]) -> None:
    """Store the bearer token found in the response headers, if any.

    The ``authorization`` header is looked up case-insensitively. Its value may
    be a single string or a list of values; the first entry of the form
    ``Bearer <token>`` with a non-empty token is passed to the factory's
    ``set_token``. Entries with another scheme, entries that are malformed, and
    non-text entries are skipped instead of raising.

    Args:
        headers: Response headers, mapping header name to a value or to a list
            of values.
    """
    auth_values = None
    for name, value in headers.items():
        if isinstance(name, str) and name.lower() == "authorization":
            auth_values = value
            break
    if not auth_values:
        return

    # authenticate_basic_token() returns a list; normalise so a bare string is
    # handled by the same loop.
    if not isinstance(auth_values, (list, tuple)):
        auth_values = [auth_values]

    for auth_header in auth_values:
        if isinstance(auth_header, bytes):
            auth_header = auth_header.decode("utf-8", errors="replace")
        if not isinstance(auth_header, str):
            continue
        # partition never raises, unlike unpacking the result of split() on an
        # entry that contains no space.
        auth_type, _, token = auth_header.partition(" ")
        if auth_type == "Bearer" and token:
            self._factory.set_token(token)
            return