From 7cbb64bfd1d2d99ec606aee008250486f7c74567 Mon Sep 17 00:00:00 2001 From: Jacob Sznajdman Date: Thu, 22 Jun 2023 12:11:14 +0200 Subject: [PATCH 01/20] Implement client endpoints for gnn/graph sage training --- graphdatascience/endpoints.py | 3 ++- graphdatascience/gnn/__init__.py | 0 graphdatascience/gnn/gnn_endpoints.py | 17 ++++++++++++++++ graphdatascience/gnn/gnn_nc_runner.py | 21 ++++++++++++++++++++ graphdatascience/ignored_server_endpoints.py | 1 + 5 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 graphdatascience/gnn/__init__.py create mode 100644 graphdatascience/gnn/gnn_endpoints.py create mode 100644 graphdatascience/gnn/gnn_nc_runner.py diff --git a/graphdatascience/endpoints.py b/graphdatascience/endpoints.py index 4abd44247..8df5e5e30 100644 --- a/graphdatascience/endpoints.py +++ b/graphdatascience/endpoints.py @@ -1,5 +1,6 @@ from .algo.single_mode_algo_endpoints import SingleModeAlgoEndpoints from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder +from .gnn.gnn_endpoints import GnnEndpoints from .graph.graph_endpoints import ( GraphAlphaEndpoints, GraphBetaEndpoints, @@ -32,7 +33,7 @@ """ -class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints): +class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints): def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion): super().__init__(query_runner, namespace, server_version) diff --git a/graphdatascience/gnn/__init__.py b/graphdatascience/gnn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/gnn/gnn_endpoints.py b/graphdatascience/gnn/gnn_endpoints.py new file mode 100644 index 000000000..e140eb8fc --- /dev/null +++ b/graphdatascience/gnn/gnn_endpoints.py @@ -0,0 +1,17 @@ +from .gnn_nc_runner import GNNNodeClassificationRunner +from ..caller_base import CallerBase +from ..error.illegal_attr_checker import IllegalAttrChecker +from ..error.uncallable_namespace import UncallableNamespace + +class GNNRunner(UncallableNamespace, IllegalAttrChecker): + @property + def nodeClassification(self) -> GNNNodeClassificationRunner: + return GNNNodeClassificationRunner(self._query_runner, f"{self._namespace}.nodeClassification", self._server_version) + +class GnnEndpoints(CallerBase): + @property + def gnn(self) -> GNNRunner: + return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version) + + + diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py new file mode 100644 index 000000000..d898176f0 --- /dev/null +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -0,0 +1,21 @@ +from typing import Any, List + +from ..error.illegal_attr_checker import IllegalAttrChecker +from ..error.uncallable_namespace import UncallableNamespace +import json + + +class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): + def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str, + target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": + configMap = { + "featureProperties": feature_properties, + "targetProperty": target_property, + } + node_properties = feature_properties + [target_property] + if target_node_label: + configMap["targetNodeLabel"] = target_node_label + mlTrainingConfig = json.dumps(configMap) + # TODO query avaiable node labels + node_labels = ["Paper"] if not node_labels else node_labels + self._query_runner.run_query(f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})") diff --git a/graphdatascience/ignored_server_endpoints.py b/graphdatascience/ignored_server_endpoints.py index 89ad9f0b2..d103a90c4 100644 --- a/graphdatascience/ignored_server_endpoints.py +++ b/graphdatascience/ignored_server_endpoints.py @@ -47,6 +47,7 @@ "gds.alpha.pipeline.nodeRegression.predict.stream", "gds.alpha.pipeline.nodeRegression.selectFeatures", "gds.alpha.pipeline.nodeRegression.train", + "gds.gnn.nc", "gds.similarity.cosine", "gds.similarity.euclidean", "gds.similarity.euclideanDistance", From 5117eea82fcaebc89605740a8d23dc2241b076a2 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 6 Jul 2023 17:41:52 +0200 Subject: [PATCH 02/20] Add predict endpoint to GNN NC runner Co-authored-by: Jacob Sznajdman Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index d898176f0..64dbb44b1 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -11,11 +11,26 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str], configMap = { "featureProperties": feature_properties, "targetProperty": target_property, + "job_type": "train", } node_properties = feature_properties + [target_property] if target_node_label: configMap["targetNodeLabel"] = target_node_label mlTrainingConfig = json.dumps(configMap) - # TODO query avaiable node labels + # TODO query available node labels node_labels = ["Paper"] if not node_labels else node_labels self._query_runner.run_query(f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})") + + + def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": + configMap = { + "featureProperties": feature_properties, + "job_type": "predict", + } + if target_node_label: + configMap["targetNodeLabel"] = target_node_label + mlTrainingConfig = json.dumps(configMap) + # TODO query available node labels + node_labels = ["Paper"] if not node_labels else node_labels + self._query_runner.run_query( + f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})") From 1e8237e19d4adb0d54b514b6e841fedea13627f6 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Thu, 6 Jul 2023 18:00:12 +0200 Subject: [PATCH 03/20] Add notebook illustrating usage of new GNN stuff --- examples/python-runtime.ipynb | 83 +++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 examples/python-runtime.ipynb diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb new file mode 100644 index 000000000..d0a44265f --- /dev/null +++ b/examples/python-runtime.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "DBID = \"beefbeef\"\n", + "ENVIRONMENT = \"\"\n", + "PASSWORD = \"\"\n", + "\n", + "from graphdatascience import GraphDataScience\n", + "\n", + "gds = GraphDataScience(\n", + " f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD)\n", + ")\n", + "gds.set_database(\"neo4j\")\n", + "\n", + "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "try:\n", + " gds.graph.load_cora()\n", + "except:\n", + " pass\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 95ecaa5ffa5aa3c1eba05570bd384ad2165e473b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Tue, 11 Jul 2023 11:10:07 +0200 Subject: [PATCH 04/20] Inject Arrow credentials to UploadGraph endpoint Co-authored-by: Brian Shi Co-authored-by: Olga Razvenskaia --- examples/python-runtime.ipynb | 48 +++++-------------- graphdatascience/endpoints.py | 4 +- graphdatascience/gnn/gnn_endpoints.py | 11 +++-- graphdatascience/gnn/gnn_nc_runner.py | 13 ++++- .../query_runner/arrow_query_runner.py | 31 +++++++++++- 5 files changed, 62 insertions(+), 45 deletions(-) diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb index d0a44265f..ca7d4c2f5 100644 --- a/examples/python-runtime.ipynb +++ b/examples/python-runtime.ipynb @@ -3,9 +3,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "DBID = \"beefbeef\"\n", @@ -14,68 +12,46 @@ "\n", "from graphdatascience import GraphDataScience\n", "\n", - "gds = GraphDataScience(\n", - " f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD)\n", - ")\n", + "gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n", "gds.set_database(\"neo4j\")\n", "\n", - "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n" + "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "try:\n", " gds.graph.load_cora()\n", "except:\n", - " pass\n" - ], - "metadata": { - "collapsed": false - } + " pass" + ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])\n" - ], - "metadata": { - "collapsed": false - } + "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" + ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" - ], - "metadata": { - "collapsed": false - } + ] } ], "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "name": "python" } }, "nbformat": 4, diff --git a/graphdatascience/endpoints.py b/graphdatascience/endpoints.py index 8df5e5e30..e91c1702b 100644 --- a/graphdatascience/endpoints.py +++ b/graphdatascience/endpoints.py @@ -33,7 +33,9 @@ """ -class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints): +class DirectEndpoints( + DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints +): def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion): super().__init__(query_runner, namespace, server_version) diff --git a/graphdatascience/gnn/gnn_endpoints.py b/graphdatascience/gnn/gnn_endpoints.py index e140eb8fc..ba1b7b2b7 100644 --- a/graphdatascience/gnn/gnn_endpoints.py +++ b/graphdatascience/gnn/gnn_endpoints.py @@ -1,17 +1,18 @@ -from .gnn_nc_runner import GNNNodeClassificationRunner from ..caller_base import CallerBase from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace +from .gnn_nc_runner import GNNNodeClassificationRunner + class GNNRunner(UncallableNamespace, IllegalAttrChecker): @property def nodeClassification(self) -> GNNNodeClassificationRunner: - return GNNNodeClassificationRunner(self._query_runner, f"{self._namespace}.nodeClassification", self._server_version) + return GNNNodeClassificationRunner( + self._query_runner, f"{self._namespace}.nodeClassification", self._server_version + ) + class GnnEndpoints(CallerBase): @property def gnn(self) -> GNNRunner: return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version) - - - diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 64dbb44b1..54545c020 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -13,13 +13,24 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str], "targetProperty": target_property, "job_type": "train", } + node_properties = feature_properties + [target_property] if target_node_label: configMap["targetNodeLabel"] = target_node_label mlTrainingConfig = json.dumps(configMap) # TODO query available node labels node_labels = ["Paper"] if not node_labels else node_labels - self._query_runner.run_query(f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})") + + # token and uri will be injected by arrow_query_runner + self._query_runner.run_query( + f"CALL gds.upload.graph($graph_name, $config)", + params={"graph_name": graph_name, "config": { + "mlTrainingConfig": mlTrainingConfig, + "modelName": model_name, + "nodeLabels": node_labels, + "nodeProperties": node_properties + }} + ) def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index cf648879a..acb1b1325 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -29,6 +29,9 @@ def __init__( ): self._fallback_query_runner = fallback_query_runner self._server_version = server_version + # FIXME handle version were tls cert is given + self._auth = auth + self._uri = uri host, port_string = uri.split(":") @@ -39,8 +42,9 @@ def __init__( ) client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification} + self._auth_factory = AuthFactory(auth) if auth: - client_options["middleware"] = [AuthFactory(auth)] + client_options["middleware"] = [self._auth_factory] if tls_root_certs: client_options["tls_root_certs"] = tls_root_certs @@ -129,6 +133,11 @@ def run_query( endpoint = "gds.beta.graph.relationships.stream" return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types}) + elif "gds.upload.graph" in query: + # inject parameters + params["config"]["token"] = self._get_or_request_token() + params["config"]["arrowEndpoint"] = self._uri + print(params) return self._fallback_query_runner.run_query(query, params, database, custom_error) @@ -183,11 +192,19 @@ def create_graph_constructor( return ArrowGraphConstructor( database, graph_name, self._flight_client, concurrency, undirected_relationship_types ) + + def _get_or_request_token(self) -> str: + print("get or request token") + self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) + return self._auth_factory.token() class AuthFactory(ClientMiddlewareFactory): # type: ignore def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) + print("init auth factory") + + self._auth = auth self._token: Optional[str] = None self._token_timestamp = 0 @@ -196,6 +213,7 @@ def start_call(self, info: Any) -> "AuthMiddleware": return AuthMiddleware(self) def token(self) -> Optional[str]: + print(f"current token {self._token} at {self._token_timestamp}") # check whether the token is older than 10 minutes. If so, reset it. if self._token and int(time.time()) - self._token_timestamp > 600: self._token = None @@ -206,6 +224,8 @@ def set_token(self, token: str) -> None: self._token = token self._token_timestamp = int(time.time()) + print(f"set token {self._token} time_stamp: {self._token_timestamp}") + @property def auth(self) -> Tuple[str, str]: return self._auth @@ -217,14 +237,21 @@ def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None: self._factory = factory def received_headers(self, headers: Dict[str, Any]) -> None: - auth_header: str = headers.get("Authorization", None) + auth_header: str = headers.get("authorization", None) if not auth_header: return + # authenticate_basic_token() returns a list. + # TODO We should take the first Bearer element here + if isinstance(auth_header, list): + auth_header = auth_header[0] + [auth_type, token] = auth_header.split(" ", 1) if auth_type == "Bearer": self._factory.set_token(token) def sending_headers(self) -> Dict[str, str]: + print("sending headers") + token = self._factory.token() if not token: username, password = self._factory.auth From f1977a30b37f992748010a8ea2e86d3b499d0063 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Tue, 11 Jul 2023 17:28:49 +0200 Subject: [PATCH 05/20] Remove prints --- graphdatascience/query_runner/arrow_query_runner.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index acb1b1325..b13565858 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -137,7 +137,6 @@ def run_query( # inject parameters params["config"]["token"] = self._get_or_request_token() params["config"]["arrowEndpoint"] = self._uri - print(params) return self._fallback_query_runner.run_query(query, params, database, custom_error) @@ -194,7 +193,6 @@ def create_graph_constructor( ) def _get_or_request_token(self) -> str: - print("get or request token") self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) return self._auth_factory.token() @@ -202,9 +200,6 @@ def _get_or_request_token(self) -> str: class AuthFactory(ClientMiddlewareFactory): # type: ignore def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - print("init auth factory") - - self._auth = auth self._token: Optional[str] = None self._token_timestamp = 0 @@ -213,7 +208,6 @@ def start_call(self, info: Any) -> "AuthMiddleware": return AuthMiddleware(self) def token(self) -> Optional[str]: - print(f"current token {self._token} at {self._token_timestamp}") # check whether the token is older than 10 minutes. If so, reset it. if self._token and int(time.time()) - self._token_timestamp > 600: self._token = None @@ -224,8 +218,6 @@ def set_token(self, token: str) -> None: self._token = token self._token_timestamp = int(time.time()) - print(f"set token {self._token} time_stamp: {self._token_timestamp}") - @property def auth(self) -> Tuple[str, str]: return self._auth @@ -250,8 +242,6 @@ def received_headers(self, headers: Dict[str, Any]) -> None: self._factory.set_token(token) def sending_headers(self) -> Dict[str, str]: - print("sending headers") - token = self._factory.token() if not token: username, password = self._factory.auth From ae2af59363f15071be143fd0185f17c3da964564 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Wed, 12 Jul 2023 17:10:01 +0100 Subject: [PATCH 06/20] Add all configs to CRD MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Florentin Dörre Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 54545c020..8eccd6c88 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -12,23 +12,22 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str], "featureProperties": feature_properties, "targetProperty": target_property, "job_type": "train", + "nodeProperties": feature_properties + [target_property] } - node_properties = feature_properties + [target_property] if target_node_label: configMap["targetNodeLabel"] = target_node_label + if node_labels: + configMap["nodeLabels"] = node_labels + mlTrainingConfig = json.dumps(configMap) - # TODO query available node labels - node_labels = ["Paper"] if not node_labels else node_labels # token and uri will be injected by arrow_query_runner self._query_runner.run_query( f"CALL gds.upload.graph($graph_name, $config)", params={"graph_name": graph_name, "config": { "mlTrainingConfig": mlTrainingConfig, - "modelName": model_name, - "nodeLabels": node_labels, - "nodeProperties": node_properties + "modelName": model_name }} ) From 8a681b422d5124b614c06931af0a7e352ad91261 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Mon, 17 Jul 2023 11:35:25 +0100 Subject: [PATCH 07/20] Cleanup nc_runner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Florentin Dörre Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 64 ++++++++++++------- .../query_runner/arrow_query_runner.py | 4 +- 2 files changed, 44 insertions(+), 24 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 8eccd6c88..adb49fb07 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -1,46 +1,66 @@ +import json from typing import Any, List from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace -import json class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): - def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str, - target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": - configMap = { + def train( + self, + graph_name: str, + model_name: str, + feature_properties: List[str], + target_property: str, + target_node_label: str = None, + node_labels: List[str] = None, + ) -> "Series[Any]": # noqa: F821 + mlConfigMap = { "featureProperties": feature_properties, "targetProperty": target_property, "job_type": "train", - "nodeProperties": feature_properties + [target_property] + "nodeProperties": feature_properties + [target_property], } if target_node_label: - configMap["targetNodeLabel"] = target_node_label + mlConfigMap["targetNodeLabel"] = target_node_label if node_labels: - configMap["nodeLabels"] = node_labels + mlConfigMap["nodeLabels"] = node_labels - mlTrainingConfig = json.dumps(configMap) + mlTrainingConfig = json.dumps(mlConfigMap) # token and uri will be injected by arrow_query_runner self._query_runner.run_query( - f"CALL gds.upload.graph($graph_name, $config)", - params={"graph_name": graph_name, "config": { - "mlTrainingConfig": mlTrainingConfig, - "modelName": model_name - }} - ) - + "CALL gds.upload.graph($graph_name, $config)", + params={ + "graph_name": graph_name, + "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, + }, + ) - def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": - configMap = { + def predict( + self, + graph_name: str, + model_name: str, + feature_properties: List[str], + target_node_label: str = None, + node_labels: List[str] = None, + ) -> "Series[Any]": # noqa: F821 + mlConfigMap = { "featureProperties": feature_properties, "job_type": "predict", + "nodeProperties": feature_properties, } if target_node_label: - configMap["targetNodeLabel"] = target_node_label - mlTrainingConfig = json.dumps(configMap) - # TODO query available node labels - node_labels = ["Paper"] if not node_labels else node_labels + mlConfigMap["targetNodeLabel"] = target_node_label + if node_labels: + mlConfigMap["nodeLabels"] = node_labels + + mlTrainingConfig = json.dumps(mlConfigMap) self._query_runner.run_query( - f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})") + "CALL gds.upload.graph($graph_name, $config)", + params={ + "graph_name": graph_name, + "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, + }, + ) # type: ignore diff --git a/graphdatascience/query_runner/arrow_query_runner.py b/graphdatascience/query_runner/arrow_query_runner.py index b13565858..eab64398c 100644 --- a/graphdatascience/query_runner/arrow_query_runner.py +++ b/graphdatascience/query_runner/arrow_query_runner.py @@ -191,7 +191,7 @@ def create_graph_constructor( return ArrowGraphConstructor( database, graph_name, self._flight_client, concurrency, undirected_relationship_types ) - + def _get_or_request_token(self) -> str: self._flight_client.authenticate_basic_token(self._auth[0], self._auth[1]) return self._auth_factory.token() @@ -232,7 +232,7 @@ def received_headers(self, headers: Dict[str, Any]) -> None: auth_header: str = headers.get("authorization", None) if not auth_header: return - # authenticate_basic_token() returns a list. + # authenticate_basic_token() returns a list. # TODO We should take the first Bearer element here if isinstance(auth_header, list): auth_header = auth_header[0] From 84b8cef6677fae3b7b6f8236e50ce82e4b30f19e Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Tue, 18 Jul 2023 16:03:44 +0100 Subject: [PATCH 08/20] Parse remote ML configs --- graphdatascience/gnn/gnn_nc_runner.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index adb49fb07..a18f4499d 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -12,6 +12,7 @@ def train( model_name: str, feature_properties: List[str], target_property: str, + relationship_types: List[str], target_node_label: str = None, node_labels: List[str] = None, ) -> "Series[Any]": # noqa: F821 @@ -20,6 +21,7 @@ def train( "targetProperty": target_property, "job_type": "train", "nodeProperties": feature_properties + [target_property], + "relationshipTypes": relationship_types } if target_node_label: @@ -31,10 +33,9 @@ def train( # token and uri will be injected by arrow_query_runner self._query_runner.run_query( - "CALL gds.upload.graph($graph_name, $config)", + "CALL gds.upload.graph($config)", params={ - "graph_name": graph_name, - "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, + "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, }, ) @@ -43,6 +44,7 @@ def predict( graph_name: str, model_name: str, feature_properties: List[str], + relationship_types: List[str], target_node_label: str = None, node_labels: List[str] = None, ) -> "Series[Any]": # noqa: F821 @@ -50,6 +52,7 @@ def predict( "featureProperties": feature_properties, "job_type": "predict", "nodeProperties": feature_properties, + "relationshipTypes": relationship_types } if target_node_label: mlConfigMap["targetNodeLabel"] = target_node_label @@ -58,9 +61,8 @@ def predict( mlTrainingConfig = json.dumps(mlConfigMap) self._query_runner.run_query( - "CALL gds.upload.graph($graph_name, $config)", + "CALL gds.upload.graph($config)", params={ - "graph_name": graph_name, - "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, + "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, }, ) # type: ignore From 82087b7829b5dfad122108784ccc90cb7c102742 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Thu, 20 Jul 2023 15:30:11 +0100 Subject: [PATCH 09/20] Add mutateProperty MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Florentin Dörre Co-authored-by: Jacob Sznajdman Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index a18f4499d..261443a87 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -45,6 +45,7 @@ def predict( model_name: str, feature_properties: List[str], relationship_types: List[str], + mutateProperty: str, target_node_label: str = None, node_labels: List[str] = None, ) -> "Series[Any]": # noqa: F821 @@ -52,7 +53,8 @@ def predict( "featureProperties": feature_properties, "job_type": "predict", "nodeProperties": feature_properties, - "relationshipTypes": relationship_types + "relationshipTypes": relationship_types, + "mutateProperty": mutateProperty } if target_node_label: mlConfigMap["targetNodeLabel"] = target_node_label From 8e822bed26c57e4ecbb0b27d50418bef49b2c358 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Mon, 31 Jul 2023 16:12:38 +0100 Subject: [PATCH 10/20] Short form of predict call Co-authored-by: Jacob Sznajdman --- graphdatascience/gnn/gnn_nc_runner.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 261443a87..90ee4b56b 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -43,23 +43,15 @@ def predict( self, graph_name: str, model_name: str, - feature_properties: List[str], - relationship_types: List[str], mutateProperty: str, - target_node_label: str = None, - node_labels: List[str] = None, + predictedProbabilityProperty: str = None, ) -> "Series[Any]": # noqa: F821 mlConfigMap = { - "featureProperties": feature_properties, "job_type": "predict", - "nodeProperties": feature_properties, - "relationshipTypes": relationship_types, "mutateProperty": mutateProperty } - if target_node_label: - mlConfigMap["targetNodeLabel"] = target_node_label - if node_labels: - mlConfigMap["nodeLabels"] = node_labels + if predictedProbabilityProperty: + mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty mlTrainingConfig = json.dumps(mlConfigMap) self._query_runner.run_query( From 74d563975a55a85252b7be3ee6a07887bd171846 Mon Sep 17 00:00:00 2001 From: Jacob Sznajdman Date: Thu, 27 Jul 2023 17:56:08 +0200 Subject: [PATCH 11/20] Expose graphsage training configuration Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 90ee4b56b..37f7125d9 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -6,6 +6,21 @@ class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): + def make_graph_sage_config(self, graph_sage_config): + GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, + "hidden_channels": 256} + final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG + if graph_sage_config: + bad_keys = [] + for key in graph_sage_config: + if key not in GRAPH_SAGE_DEFAULT_CONFIG: + bad_keys.append(key) + if len(bad_keys) > 0: + raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.") + + final_sage_config.update(graph_sage_config) + return final_sage_config + def train( self, graph_name: str, @@ -15,13 +30,15 @@ def train( relationship_types: List[str], target_node_label: str = None, node_labels: List[str] = None, + graph_sage_config = None ) -> "Series[Any]": # noqa: F821 mlConfigMap = { "featureProperties": feature_properties, "targetProperty": target_property, "job_type": "train", "nodeProperties": feature_properties + [target_property], - "relationshipTypes": relationship_types + "relationshipTypes": relationship_types, + "graph_sage_config": self.make_graph_sage_config(graph_sage_config) } if target_node_label: From fa544887ea83a68bb5005411e2b1dc4d43983be3 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Tue, 1 Aug 2023 11:19:15 +0100 Subject: [PATCH 12/20] Add learning rate Co-authored-by: Jacob Sznajdman --- graphdatascience/gnn/gnn_nc_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 37f7125d9..27aec8d63 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -8,7 +8,7 @@ class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): def make_graph_sage_config(self, graph_sage_config): GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, - "hidden_channels": 256} + "hidden_channels": 256, "learning_rate": 0.003} final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG if graph_sage_config: bad_keys = [] From e6849590b7b3100585bc1262a93b98bd178569e0 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Tue, 1 Aug 2023 16:15:09 +0100 Subject: [PATCH 13/20] Update python-runtime notebook --- examples/python-runtime.ipynb | 167 ++++++++++++++++++++++++++++++---- 1 file changed, 148 insertions(+), 19 deletions(-) diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb index ca7d4c2f5..da2118da3 100644 --- a/examples/python-runtime.ipynb +++ b/examples/python-runtime.ipynb @@ -2,27 +2,68 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "DBID = \"beefbeef\"\n", - "ENVIRONMENT = \"\"\n", - "PASSWORD = \"\"\n", - "\n", - "from graphdatascience import GraphDataScience\n", - "\n", - "gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n", - "gds.set_database(\"neo4j\")\n", - "\n", - "gds.gnn.nodeClassification.train(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" + "from graphdatascience import GraphDataScience" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 2, + "outputs": [], + "source": [ + "ENVIRONMENT = \"mlruntimedev\"\n", + "DBID = \"e6ba1b5c\"\n", + "PASSWORD = \"l4Co2Qa5GseW0sMropCvJo17laf6ZCq9vuAhiJrVW2c\"" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 3, "outputs": [], + "source": [ + "gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n", + "gds.set_database(\"neo4j\")" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": "Uploading Nodes: 0%| | 0/2708 [00:00\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n
gds.remoteml.getTrainResult('model2')
0{'test_acc_mean': 0.8589511513710022, 'test_ac...
\n" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_result" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "gds.gnn.nodeClassification.predict(\"cora\", \"model\", [\"features\"], \"subject\", node_labels=[\"Paper\"])" + "predict_result = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")" ] + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "cora = gds.graph.get('cora')" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "predictions = gds.graph.nodeProperties.stream(cora, node_properties=[\"features\", \"myPredictions\"], separate_property_columns=True)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "data": { + "text/plain": " nodeId features \\\n0 31336 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n1 1061127 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ... \n2 1106406 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n3 13195 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n4 37879 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n... ... ... \n2703 1128975 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2704 1128977 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2705 1128978 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2706 117328 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2707 24043 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n\n model2Predictions \n0 0 \n1 1 \n2 2 \n3 2 \n4 3 \n... ... \n2703 5 \n2704 5 \n2705 5 \n2706 6 \n2707 0 \n\n[2708 rows x 3 columns]", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
nodeIdfeaturesmodel2Predictions
031336[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...0
11061127[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...1
21106406[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...2
313195[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...2
437879[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...3
............
27031128975[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...5
27041128977[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...5
27051128978[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...5
2706117328[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...6
270724043[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...0
\n

2708 rows × 3 columns

\n
" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predictions" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } } ], "metadata": { From b556c266cd2d581cf332ce3fafdddcc4568515f6 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Fri, 4 Aug 2023 16:48:17 +0200 Subject: [PATCH 14/20] Adapt notebook to field testing experience --- examples/python-runtime.ipynb | 222 +++++++++++++++++++--------------- 1 file changed, 122 insertions(+), 100 deletions(-) diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb index da2118da3..952bb7f5e 100644 --- a/examples/python-runtime.ipynb +++ b/examples/python-runtime.ipynb @@ -2,69 +2,45 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "# Make sure you have installed the custom GDS Client distributed with this notebook\n", "from graphdatascience import GraphDataScience" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "ENVIRONMENT = \"mlruntimedev\"\n", - "DBID = \"e6ba1b5c\"\n", - "PASSWORD = \"l4Co2Qa5GseW0sMropCvJo17laf6ZCq9vuAhiJrVW2c\"" - ], - "metadata": { - "collapsed": false - } + "# From the Aura Console, get the Connection URI to your Neo4j instance and paste here\n", + "URI = \"neo4j+s://-mlruntimedev.databases.neo4j-dev.io\"\n", + "# And paste the database password here\n", + "PASSWORD = \"\"" + ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "gds = GraphDataScience(f\"neo4j+s://{DBID}-{ENVIRONMENT}.databases.neo4j-dev.io/\", auth=(\"neo4j\", PASSWORD))\n", + "# The usual GDS client initialization\n", + "gds = GraphDataScience(URI, auth=(\"neo4j\", PASSWORD))\n", "gds.set_database(\"neo4j\")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": "Uploading Nodes: 0%| | 0/2708 [00:00\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n
gds.remoteml.getTrainResult('model2')
0{'test_acc_mean': 0.8589511513710022, 'test_ac...
\n" - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ + "# And display it\n", "train_result" - ], + ] + }, + { + "cell_type": "markdown", "metadata": { "collapsed": false - } + }, + "source": [ + "# GNN prediction!\n", + "\n", + "Wow, that was cool.\n", + "But training a model is only half the picture.\n", + "We also have to use it for something.\n", + "In this case, we will use it to predict the subject of papers in the Cora dataset.\n", + "\n", + "Again, this call is asynchronous, so it will return immediately.\n", + "\n", + "TODO: instructions for inspecting the log\n", + "\n", + "Once the prediction is completed, the predicted classes are added to GDS Graph Catalog (as per normal).\n", + "We can retrieve the prediction result (the predictions themselves) by streaming from the graph.\n" + ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "# Let's trigger prediction!\n", "predict_result = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "cora = gds.graph.get('cora')" - ], - "metadata": { - "collapsed": false - } + "# Let's get a graph object\n", + "cora = gds.graph.get(\"cora\")" + ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "predictions = gds.graph.nodeProperties.stream(cora, node_properties=[\"features\", \"myPredictions\"], separate_property_columns=True)" - ], - "metadata": { - "collapsed": false - } + "# Now for some standard GDS stuff; streaming properties from the graph\n", + "predictions = gds.graph.nodeProperties.stream(\n", + " cora, node_properties=[\"features\", \"myPredictions\"], separate_property_columns=True\n", + ")" + ] }, { "cell_type": "code", - "execution_count": 12, - "outputs": [ - { - "data": { - "text/plain": " nodeId features \\\n0 31336 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n1 1061127 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ... \n2 1106406 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n3 13195 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n4 37879 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n... ... ... \n2703 1128975 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2704 1128977 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2705 1128978 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2706 117328 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n2707 24043 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n\n model2Predictions \n0 0 \n1 1 \n2 2 \n3 2 \n4 3 \n... ... \n2703 5 \n2704 5 \n2705 5 \n2706 6 \n2707 0 \n\n[2708 rows x 3 columns]", - "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
nodeIdfeaturesmodel2Predictions
031336[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...0
11061127[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...1
21106406[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...2
313195[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...2
437879[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...3
............
27031128975[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...5
27041128977[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...5
27051128978[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...5
2706117328[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...6
270724043[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...0
\n

2708 rows × 3 columns

\n
" - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ + "# And displaying them\n", "predictions" - ], - "metadata": { - "collapsed": false - } + ] }, { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], + "cell_type": "markdown", "metadata": { "collapsed": false - } + }, + "source": [ + "# And that's it!\n", + "\n", + "Thank you very much for participating in the testing.\n", + "We hope you enjoyed it.\n", + "If you've run the notebook for the first time, now's the time to experiment and changing graph, training parameters, etc.\n", + "If you're feeling like you're done, please reach back to the Google Document and fill in our feedback form.\n", + "\n", + "Thank you!" + ] } ], "metadata": { From fb179327685117507dfa02b002548fece8a3bd26 Mon Sep 17 00:00:00 2001 From: Jacob Sznajdman Date: Mon, 7 Aug 2023 15:21:09 +0200 Subject: [PATCH 15/20] Implement watch_logs and expose job_id Co-Authored-By: Mats Rydberg --- graphdatascience/gnn/gnn_nc_runner.py | 55 +++++++++++++++++++-------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 27aec8d63..9305797cd 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -1,5 +1,6 @@ import json from typing import Any, List +import time from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace @@ -7,8 +8,13 @@ class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): def make_graph_sage_config(self, graph_sage_config): - GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, - "hidden_channels": 256, "learning_rate": 0.003} + GRAPH_SAGE_DEFAULT_CONFIG = { + "layer_config": {}, + "num_neighbors": [25, 10], + "dropout": 0.5, + "hidden_channels": 256, + "learning_rate": 0.003, + } final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG if graph_sage_config: bad_keys = [] @@ -21,6 +27,21 @@ def make_graph_sage_config(self, graph_sage_config): final_sage_config.update(graph_sage_config) return final_sage_config + def watch_logs(self, job_id: str, logging_interval: int = 5): + def get_logs(offset) -> "Series[Any]": # noqa: F821 + return self._query_runner.run_query( + "RETURN gds.remoteml.getLogs($job_id, $offset)", params={"job_id": job_id, "offset": offset} + ).squeeze() + + received_logs = 0 + training_done = False + while not training_done: + time.sleep(logging_interval) + for log in get_logs(offset=received_logs): + print(log) + received_logs += 1 + return job_id + def train( self, graph_name: str, @@ -30,15 +51,15 @@ def train( relationship_types: List[str], target_node_label: str = None, node_labels: List[str] = None, - graph_sage_config = None - ) -> "Series[Any]": # noqa: F821 + graph_sage_config=None, + ) -> str: mlConfigMap = { "featureProperties": feature_properties, "targetProperty": target_property, "job_type": "train", "nodeProperties": feature_properties + [target_property], "relationshipTypes": relationship_types, - "graph_sage_config": self.make_graph_sage_config(graph_sage_config) + "graph_sage_config": self.make_graph_sage_config(graph_sage_config), } if target_node_label: @@ -49,12 +70,15 @@ def train( mlTrainingConfig = json.dumps(mlConfigMap) # token and uri will be injected by arrow_query_runner - self._query_runner.run_query( - "CALL gds.upload.graph($config)", + job_id = self._query_runner.run_query( + "CALL gds.upload.graph($config) YIELD jobId", params={ "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, }, - ) + ).jobId[0] + + print(f"Started job with jobId={job_id}. Use `gds.gnn.nodeClassification.watch_logs` to track progress.") + return job_id def predict( self, @@ -62,18 +86,19 @@ def predict( model_name: str, mutateProperty: str, predictedProbabilityProperty: str = None, + logging_interval=5, ) -> "Series[Any]": # noqa: F821 - mlConfigMap = { - "job_type": "predict", - "mutateProperty": mutateProperty - } + mlConfigMap = {"job_type": "predict", "mutateProperty": mutateProperty} if predictedProbabilityProperty: mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty mlTrainingConfig = json.dumps(mlConfigMap) - self._query_runner.run_query( - "CALL gds.upload.graph($config)", + job_id = self._query_runner.run_query( + "CALL gds.upload.graph($config) YIELD jobId", params={ "config": {"mlTrainingConfig": mlTrainingConfig, "graphName": graph_name, "modelName": model_name}, }, - ) # type: ignore + ).jobId[0] + + print(f"Started job with jobId={job_id}. Use `gds.gnn.nodeClassification.watch_logs` to track progress.") + return job_id From e76707d461fa90caf09237c65120273ee6c56d0b Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 8 Aug 2023 17:08:17 +0200 Subject: [PATCH 16/20] Declare correct return type --- graphdatascience/gnn/gnn_nc_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 9305797cd..ac8053d80 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -86,8 +86,7 @@ def predict( model_name: str, mutateProperty: str, predictedProbabilityProperty: str = None, - logging_interval=5, - ) -> "Series[Any]": # noqa: F821 + ) -> str: mlConfigMap = {"job_type": "predict", "mutateProperty": mutateProperty} if predictedProbabilityProperty: mlConfigMap["predictedProbabilityProperty"] = predictedProbabilityProperty From ce06fc7a0043467d88869505cb16b3a9d190d5fe Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 8 Aug 2023 17:08:46 +0200 Subject: [PATCH 17/20] Print helpful messages when watching logs --- graphdatascience/gnn/gnn_nc_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index ac8053d80..14864efcf 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -1,6 +1,6 @@ import json -from typing import Any, List import time +from typing import Any, List from ..error.illegal_attr_checker import IllegalAttrChecker from ..error.uncallable_namespace import UncallableNamespace @@ -28,6 +28,9 @@ def make_graph_sage_config(self, graph_sage_config): return final_sage_config def watch_logs(self, job_id: str, logging_interval: int = 5): + print(f"Watching logs of job {job_id}.") + print("This needs to be interrupted manually in order to continue (for example when training is done).") + def get_logs(offset) -> "Series[Any]": # noqa: F821 return self._query_runner.run_query( "RETURN gds.remoteml.getLogs($job_id, $offset)", params={"job_id": job_id, "offset": offset} From d898adbd0633fd9e6224bb45256d13d8bf643d87 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 8 Aug 2023 17:09:17 +0200 Subject: [PATCH 18/20] Extend notebook to include logging feedback feature Co-authored-by: Jacob Sznajdman Co-authored-by: Brian Shi --- examples/python-runtime.ipynb | 68 ++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/examples/python-runtime.ipynb b/examples/python-runtime.ipynb index 952bb7f5e..534bb651b 100644 --- a/examples/python-runtime.ipynb +++ b/examples/python-runtime.ipynb @@ -71,19 +71,41 @@ "It happens asynchronously, so it will return immediately (unless there's an unexpected error 😱).\n", "Of course, the training does not complete instantly, so you will have to wait for it to finish.\n", "\n", - "TODO: instructions for inspecting the log\n", + "## Observing the training progress\n", + "\n", + "You can observe the training progress by watching the logs.\n", + "This is done in the subsequent cell.\n", + "The watching doesn't automatically stop, so you will have to stop it manually.\n", + "Once you see the message 'Training Done', you can interrupt the cell and continue.\n", + "\n", + "## Graph and training parameters\n", + "\n", + "\n", + "\n", + "\n", + "| Parameter | Default | Type | Description |\n", + "|--------------------|----------------|----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| graph_name | - | str | The name of the graph to train on. |\n", + "| model_name | - | str | The name of the model. Must be unique per database and username combination. Models cannot be cleaned up at this time. |\n", + "| feature_properties | - | List[str] | The node properties to use as model features. |\n", + "| target_property | - | str | The node property that contains the target class values. |\n", + "| node_labels | None | List[str] | The node labels to use for training. By default, all labels are used. |\n", + "| relationship_types | None | List[str] | The relationship types to use for training. By default, all types are used. |\n", + "| target_node_label | None | str | Indicates the nodes used for training. Only nodes with this label need to have the `target_property` defined. Other nodes are used for context. By default, all nodes are considered. |\n", + "| graph_sage_config | None | dict | Configuration for the GraphSAGE training. See below. |\n", + "\n", "\n", "## GraphSAGE parameters\n", "\n", "We have exposed several parameters of the PyG GraphSAGE model.\n", "\n", - "| Parameter | Default | Description |\n", - "|-----------------|----------|-------------|\n", - "| layer_config | {} | ??? |\n", - "| num_neighbors | [25, 10] | ??? |\n", - "| dropout | 0.5 | ??? |\n", - "| hidden_channels | 256 | ??? |\n", - "| learning_rate | 0.003 | ??? |\n", + "| Parameter | Default | Description |\n", + "|-----------------|----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| layer_config | {} | Configuration of the GraphSAGE layers. It supports `aggr`, `normalize`, `root_weight`, `project`, `bias` from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html). Additionally, you can provide message passing configuration from [this link](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MessagePassing.html#torch_geometric.nn.conv.MessagePassing). |\n", + "| num_neighbors | [25, 10] | Sample sizes for each layer. The length of this list is the number of layers used. All numbers must be >0. |\n", + "| dropout | 0.5 | Probability of dropping out neurons during training. Must be between 0 and 1. |\n", + "| hidden_channels | 256 | The dimension of each hidden layer. Higher value means more expensive training, but higher level of representation. Must be >0. |\n", + "| learning_rate | 0.003 | The learning rate. Must be >0. |\n", "\n", "Please try to use any of them with any useful values.\n" ] @@ -95,11 +117,21 @@ "outputs": [], "source": [ "# Let's train!\n", - "train_response = gds.gnn.nodeClassification.train(\n", + "job_id = gds.gnn.nodeClassification.train(\n", " \"cora\", \"myModel\", [\"features\"], \"subject\", [\"CITES\"], target_node_label=\"Paper\", node_labels=[\"Paper\"]\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# And let's follow the progress by watching the logs\n", + "gds.gnn.nodeClassification.watch_logs(job_id)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -134,8 +166,7 @@ "In this case, we will use it to predict the subject of papers in the Cora dataset.\n", "\n", "Again, this call is asynchronous, so it will return immediately.\n", - "\n", - "TODO: instructions for inspecting the log\n", + "Observe the progress by watching the logs.\n", "\n", "Once the prediction is completed, the predicted classes are added to GDS Graph Catalog (as per normal).\n", "We can retrieve the prediction result (the predictions themselves) by streaming from the graph.\n" @@ -148,7 +179,17 @@ "outputs": [], "source": [ "# Let's trigger prediction!\n", - "predict_result = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")" + "job_id = gds.gnn.nodeClassification.predict(\"cora\", \"myModel\", \"myPredictions\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# And let's follow progress by watching the logs\n", + "gds.gnn.nodeClassification.watch_logs(job_id)" ] }, { @@ -157,7 +198,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Let's get a graph object\n", + "# Now that prediction is done, let's see the predictions\n", "cora = gds.graph.get(\"cora\")" ] }, @@ -194,6 +235,7 @@ "Thank you very much for participating in the testing.\n", "We hope you enjoyed it.\n", "If you've run the notebook for the first time, now's the time to experiment and changing graph, training parameters, etc.\n", + "For example, try out a heterogeneous graph problem? Or whether performance can be improved by changing some parameter? Run training jobs in parallel, on multiple databases?\n", "If you're feeling like you're done, please reach back to the Google Document and fill in our feedback form.\n", "\n", "Thank you!" From 30bbcf36f4b3419829232c60021b6fe93e201616 Mon Sep 17 00:00:00 2001 From: Mats Rydberg Date: Tue, 8 Aug 2023 17:10:05 +0200 Subject: [PATCH 19/20] Rename notebook to V1 --- examples/{python-runtime.ipynb => python-runtime-V1.ipynb} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{python-runtime.ipynb => python-runtime-V1.ipynb} (100%) diff --git a/examples/python-runtime.ipynb b/examples/python-runtime-V1.ipynb similarity index 100% rename from examples/python-runtime.ipynb rename to examples/python-runtime-V1.ipynb From 68358d1a108f4ee96674c33f60b4cfcfeab5c052 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Thu, 10 Aug 2023 11:40:43 +0200 Subject: [PATCH 20/20] Assert correct version is used Co-authored-by: Mats Rydberg --- examples/python-runtime-V1.ipynb | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/python-runtime-V1.ipynb b/examples/python-runtime-V1.ipynb index 534bb651b..5912ad396 100644 --- a/examples/python-runtime-V1.ipynb +++ b/examples/python-runtime-V1.ipynb @@ -6,8 +6,11 @@ "metadata": {}, "outputs": [], "source": [ + "from graphdatascience import GraphDataScience\n", + "from graphdatascience import __version__\n", + "\n", "# Make sure you have installed the custom GDS Client distributed with this notebook\n", - "from graphdatascience import GraphDataScience" + "assert __version__ == \"1.8a1.dev1\"" ] }, {