44import time
55from typing import Any , Dict , Optional
66
7- import pandas as pd
7+ import pyarrow
88import requests
99from pandas import DataFrame , Series
1010
@@ -32,12 +32,13 @@ def __init__(
3232 self ._namespace = namespace
3333 self ._server_version = server_version
3434 self ._compute_cluster_web_uri = f"http://{ compute_cluster_ip } :5005"
35+ self ._compute_cluster_arrow_uri = f"grpc://{ compute_cluster_ip } :8815"
3536 self ._compute_cluster_mlflow_uri = f"http://{ compute_cluster_ip } :8080"
3637 self ._encrypted_db_password = encrypted_db_password
3738 self ._arrow_uri = arrow_uri
3839
@property
def model(self) -> "KgeRunner":
    """Return this runner itself, so calls can be chained through ``.model``."""
    return self
4243
4344 # @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
@@ -75,7 +76,7 @@ def train(
7576 mlflow_experiment_name : Optional [str ] = None ,
7677 ) -> Series :
7778 if epochs_per_checkpoint is None :
78- epochs_per_checkpoint = max (num_epochs / 10 , 1 )
79+ epochs_per_checkpoint = max (int ( num_epochs / 10 ) , 1 )
7980 if loss_function_kwargs is None :
8081 loss_function_kwargs = dict (margin = 1.0 , adversarial_temperature = 1.0 , gamma = 20.0 )
8182 if lr_scheduler_kwargs is None :
@@ -92,7 +93,7 @@ def train(
9293 }
9394 print (algo_config )
9495
95- graph_config = {"name" : G .name ()}
96+ graph_config = {"name" : G .name (), "config_type" : "GdsGraphConfig" }
9697
9798 config = {
9899 "user_name" : "DUMMY_USER" ,
@@ -144,8 +145,10 @@ def predict(
144145 "user_name" : "DUMMY_USER" ,
145146 "task" : "KGE_PREDICT_PYG" ,
146147 "task_config" : {
148+ "graph_config" : {"config_type" : "GdsGraphConfig" , "name" : "NOGRAPH" },
147149 "modelname" : model_name ,
148150 "task_config" : algo_config ,
151+ "stream_rel_results" : True ,
149152 },
150153 "graph_arrow_uri" : self ._arrow_uri ,
151154 }
@@ -162,7 +165,7 @@ def predict(
162165
163166 self ._wait_for_job (job_id )
164167
165- return self ._stream_results (config [ "user_name" ], config [ "task_config" ][ "modelname" ] , job_id )
168+ return self ._stream_results (config , job_id )
166169
167170 @client_only_endpoint ("gds.kge.model" )
168171 def score_triplets (
@@ -180,8 +183,10 @@ def score_triplets(
180183 "user_name" : "DUMMY_USER" ,
181184 "task" : "KGE_SCORE_TRIPLETS_PYG" ,
182185 "task_config" : {
186+ "graph_config" : {"config_type" : "GdsGraphConfig" , "name" : "NOGRAPH" },
183187 "modelname" : model_name ,
184188 "task_config" : algo_config ,
189+ "stream_rel_results" : True ,
185190 },
186191 "graph_arrow_uri" : self ._arrow_uri ,
187192 }
@@ -198,22 +203,20 @@ def score_triplets(
198203
199204 self ._wait_for_job (job_id )
200205
201- return self ._stream_results (config [ "user_name" ], config [ "task_config" ][ "modelname" ] , job_id )
206+ return self ._stream_results (config , job_id )
202207
203- def _stream_results (self , user_name : str , model_name : str , job_id : str ) -> DataFrame :
204- res = requests .get (
205- f"{ self ._compute_cluster_web_uri } /internal/fetch-result" ,
206- params = {"user_name" : user_name , "modelname" : model_name , "job_id" : job_id },
207- )
208- res .raise_for_status ()
208+ def _stream_results (self , config : dict , job_id : str ) -> DataFrame :
209+ client = pyarrow .flight .connect (self ._compute_cluster_arrow_uri )
209210
210- res_file_name = f"res_{ job_id } .json"
211- with open (res_file_name , mode = "wb+" ) as f :
212- f .write (res .content )
211+ if config ["task_config" ].get ("stream_rel_results" , False ):
212+ upload_descriptor = pyarrow .flight .FlightDescriptor .for_path (f"{ job_id } .relationships" )
213+ else :
214+ raise ValueError ("No results to fetch: need to set stream_rel_results or stream_graph_results to True" )
215+ flight = client .get_flight_info (upload_descriptor )
216+ reader = client .do_get (flight .endpoints [0 ].ticket )
217+ read_table = reader .read_all ()
213218
214- df = pd .read_json (res_file_name , orient = "records" , lines = True )
215- os .remove (res_file_name )
216- return df
219+ return read_table .to_pandas ()
217220
218221 def _get_metrics (self , user_name : str , model_name : str , job_id : str ) -> DataFrame :
219222 res = requests .get (
0 commit comments