6 changes: 6 additions & 0 deletions graphdatascience/procedure_surface/api/catalog_endpoints.py
@@ -9,6 +9,7 @@

from graphdatascience import Graph
from graphdatascience.procedure_surface.api.base_result import BaseResult
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints


class CatalogEndpoints(ABC):
@@ -65,6 +66,11 @@ def filter(
"""
pass

@property
@abstractmethod
def sample(self) -> GraphSamplingEndpoints:
pass


class GraphListResult(BaseResult):
graph_name: str
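The new abstract `sample` property exposes graph sampling from the catalog surface. A minimal usage sketch, assuming a concrete `CatalogEndpoints` implementation is bound to `catalog` and a projected `Graph` to `G` (both names are placeholders, not part of this change):

```python
# Hypothetical usage of the new `sample` property on a concrete catalog.
# `catalog` and `G` are assumed to already exist.
result = catalog.sample.rwr(
    G,                        # the projected source graph
    "my_sampled_graph",       # name under which the sample is stored in the catalog
    sampling_ratio=0.15,      # keep roughly 15% of the nodes (the documented default)
    restart_probability=0.1,  # chance of restarting a walk at one of the start nodes
)
print(result.node_count, result.relationship_count)
```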
170 changes: 170 additions & 0 deletions graphdatascience/procedure_surface/api/graph_sampling_endpoints.py
@@ -0,0 +1,170 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, List, Optional

from graphdatascience import Graph
from graphdatascience.procedure_surface.api.base_result import BaseResult


class GraphSamplingEndpoints(ABC):
"""
Abstract base class defining the API for graph sampling operations.
"""

@abstractmethod
def rwr(
self,
G: Graph,
graph_name: str,
start_nodes: Optional[List[int]] = None,
restart_probability: Optional[float] = None,
sampling_ratio: Optional[float] = None,
node_label_stratification: Optional[bool] = None,
relationship_weight_property: Optional[str] = None,
relationship_types: Optional[List[str]] = None,
node_labels: Optional[List[str]] = None,
sudo: Optional[bool] = None,
log_progress: Optional[bool] = None,
username: Optional[str] = None,
concurrency: Optional[Any] = None,
job_id: Optional[Any] = None,
) -> GraphSamplingResult:
"""
Random walk with restarts (RWR) samples the graph by taking random walks from a set of start nodes.

On each step of a random walk, there is a probability that the walk stops, and a new walk from one of the start
nodes starts instead (i.e. the walk restarts). Each node visited on these walks will be part of the sampled
subgraph. The resulting subgraph is stored as a new graph in the Graph Catalog.

Parameters
----------
G : Graph
The input graph to be sampled.
graph_name : str
The name of the new graph that is stored in the graph catalog.
start_nodes : list of int, optional
IDs of the initial set of nodes in the original graph from which the sampling random walks will start.
By default, a single node is chosen uniformly at random.
restart_probability : float, optional
The probability that a sampling random walk restarts from one of the start nodes.
Default is 0.1.
sampling_ratio : float, optional
The fraction of nodes in the original graph to be sampled.
Default is 0.15.
node_label_stratification : bool, optional
If true, preserves the node label distribution of the original graph.
Default is False.
relationship_weight_property : str, optional
Name of the relationship property to use as weights. If unspecified, the algorithm runs unweighted.
relationship_types : Optional[List[str]], default=None
Filter the named graph using the given relationship types. Relationships with any of the given types will be
included.
node_labels : Optional[List[str]], default=None
Filter the named graph using the given node labels. Nodes with any of the given labels will be included.
sudo : bool, optional
Bypass heap control. Use with caution.
Default is False.
log_progress : bool, optional
Turn percentage progress logging on or off while running the procedure.
Default is True.
username : str, optional
Use Administrator access to run an algorithm on a graph owned by another user.
Default is None.
concurrency : Any, optional
The number of concurrent threads used for running the algorithm.
Default is 4.
job_id : Any, optional
An ID that can be provided to more easily track the algorithm’s progress.
By default, a random job id is generated.

Returns
-------
GraphSamplingResult
The result of the Random Walk with Restart (RWR), including the dimensions of the sampled graph.
"""
pass

@abstractmethod
def cnarw(
self,
G: Graph,
graph_name: str,
start_nodes: Optional[List[int]] = None,
restart_probability: Optional[float] = None,
sampling_ratio: Optional[float] = None,
node_label_stratification: Optional[bool] = None,
relationship_weight_property: Optional[str] = None,
relationship_types: Optional[List[str]] = None,
node_labels: Optional[List[str]] = None,
sudo: Optional[bool] = None,
log_progress: Optional[bool] = None,
username: Optional[str] = None,
concurrency: Optional[Any] = None,
job_id: Optional[Any] = None,
) -> GraphSamplingResult:
"""
Common Neighbour Aware Random Walk (CNARW) samples the graph by taking random walks from a set of start nodes.

CNARW is a graph sampling technique that involves optimizing the selection of the next-hop node. It takes into
account the number of common neighbours between the current node and the next-hop candidates. On each step of a
random walk, there is a probability that the walk stops, and a new walk from one of the start nodes starts
instead (i.e. the walk restarts). Each node visited on these walks will be part of the sampled subgraph. The
resulting subgraph is stored as a new graph in the Graph Catalog.

Parameters
----------
G : Graph
The input graph to be sampled.
graph_name : str
The name of the new graph that is stored in the graph catalog.
start_nodes : list of int, optional
IDs of the initial set of nodes in the original graph from which the sampling random walks will start.
By default, a single node is chosen uniformly at random.
restart_probability : float, optional
The probability that a sampling random walk restarts from one of the start nodes.
Default is 0.1.
sampling_ratio : float, optional
The fraction of nodes in the original graph to be sampled.
Default is 0.15.
node_label_stratification : bool, optional
If true, preserves the node label distribution of the original graph.
Default is False.
relationship_weight_property : str, optional
Name of the relationship property to use as weights. If unspecified, the algorithm runs unweighted.
relationship_types : Optional[List[str]], default=None
Filter the named graph using the given relationship types. Relationships with any of the given types will be
included.
node_labels : Optional[List[str]], default=None
Filter the named graph using the given node labels. Nodes with any of the given labels will be included.
sudo : bool, optional
Bypass heap control. Use with caution.
Default is False.
log_progress : bool, optional
Turn percentage progress logging on or off while running the procedure.
Default is True.
username : str, optional
Use Administrator access to run an algorithm on a graph owned by another user.
Default is None.
concurrency : Any, optional
The number of concurrent threads used for running the algorithm.
Default is 4.
job_id : Any, optional
An ID that can be provided to more easily track the algorithm’s progress.
By default, a random job id is generated.

Returns
-------
GraphSamplingResult
The result of the Common Neighbour Aware Random Walk (CNARW), including the dimensions of the sampled graph.
"""
pass


class GraphSamplingResult(BaseResult):
graph_name: str
from_graph_name: str
node_count: int
relationship_count: int
start_node_count: int
project_millis: int
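Both sampling methods share the same signature and return a `GraphSamplingResult`. A hedged sketch of how a caller might exercise them, assuming `sampler` is some concrete `GraphSamplingEndpoints` and `G` a projected `Graph` (placeholder names):

```python
# RWR: plain random walk with restarts, here stratified by node label.
rwr_result = sampler.rwr(
    G,
    "people_rwr_sample",
    start_nodes=[0, 42],             # walks start from these original node ids
    node_label_stratification=True,  # preserve the original label distribution
)

# CNARW: the next-hop choice is additionally biased by common-neighbour counts.
cnarw_result = sampler.cnarw(G, "people_cnarw_sample", sampling_ratio=0.2)

# Both results expose the dimensions of the sampled graph.
for res in (rwr_result, cnarw_result):
    print(res.graph_name, res.from_graph_name, res.node_count, res.start_node_count)
```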
@@ -13,6 +13,8 @@
GraphFilterResult,
GraphListResult,
)
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import GraphSamplingEndpoints
from graphdatascience.procedure_surface.arrow.graph_sampling_arrow_endpoints import GraphSamplingArrowEndpoints
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol
from graphdatascience.query_runner.termination_flag import TerminationFlag
@@ -116,6 +118,10 @@ def filter(

return GraphFilterResult(**JobClient.get_summary(self._arrow_client, job_id))

@property
def sample(self) -> GraphSamplingEndpoints:
return GraphSamplingArrowEndpoints(self._arrow_client)

def _arrow_config(self) -> dict[str, Any]:
connection_info = self._arrow_client.advertised_connection_info()

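The Arrow-backed catalog wires the property up by constructing a fresh `GraphSamplingArrowEndpoints` around its existing authenticated client on each access. A sketch of what that implies for callers, assuming `catalog` is such a catalog instance and `G` a projected graph (placeholder names; the private `_arrow_client` attribute is shown for illustration only):

```python
# Accessing `.sample` builds a thin wrapper around the catalog's existing
# AuthenticatedArrowClient, so both calls below talk to the same Arrow backend.
result = catalog.sample.cnarw(G, "cnarw_sample", sampling_ratio=0.2)

# Equivalent explicit wiring (what the property does under the hood):
sampler = GraphSamplingArrowEndpoints(catalog._arrow_client)
result = sampler.cnarw(G, "cnarw_sample", sampling_ratio=0.2)
```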
93 changes: 93 additions & 0 deletions graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py
@@ -0,0 +1,93 @@
from __future__ import annotations

from typing import Any, List, Optional

from graphdatascience import Graph
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
from graphdatascience.arrow_client.v2.job_client import JobClient
from graphdatascience.procedure_surface.api.graph_sampling_endpoints import (
GraphSamplingEndpoints,
GraphSamplingResult,
)
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter


class GraphSamplingArrowEndpoints(GraphSamplingEndpoints):
def __init__(self, arrow_client: AuthenticatedArrowClient):
self._arrow_client = arrow_client

def rwr(
self,
G: Graph,
graph_name: str,
start_nodes: Optional[List[int]] = None,
restart_probability: Optional[float] = None,
sampling_ratio: Optional[float] = None,
node_label_stratification: Optional[bool] = None,
relationship_weight_property: Optional[str] = None,
relationship_types: Optional[List[str]] = None,
node_labels: Optional[List[str]] = None,
sudo: Optional[bool] = None,
log_progress: Optional[bool] = None,
username: Optional[str] = None,
concurrency: Optional[Any] = None,
job_id: Optional[Any] = None,
) -> GraphSamplingResult:
config = ConfigConverter.convert_to_gds_config(
from_graph_name=G.name(),
graph_name=graph_name,
startNodes=start_nodes,
restartProbability=restart_probability,
samplingRatio=sampling_ratio,
nodeLabelStratification=node_label_stratification,
relationshipWeightProperty=relationship_weight_property,
relationship_types=relationship_types,
node_labels=node_labels,
sudo=sudo,
log_progress=log_progress,
username=username,
concurrency=concurrency,
job_id=job_id,
)

job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.rwr", config)

return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id))

def cnarw(
self,
G: Graph,
graph_name: str,
start_nodes: Optional[List[int]] = None,
restart_probability: Optional[float] = None,
sampling_ratio: Optional[float] = None,
node_label_stratification: Optional[bool] = None,
relationship_weight_property: Optional[str] = None,
relationship_types: Optional[List[str]] = None,
node_labels: Optional[List[str]] = None,
sudo: Optional[bool] = None,
log_progress: Optional[bool] = None,
username: Optional[str] = None,
concurrency: Optional[Any] = None,
job_id: Optional[Any] = None,
) -> GraphSamplingResult:
config = ConfigConverter.convert_to_gds_config(
from_graph_name=G.name(),
graph_name=graph_name,
startNodes=start_nodes,
restartProbability=restart_probability,
samplingRatio=sampling_ratio,
nodeLabelStratification=node_label_stratification,
relationshipWeightProperty=relationship_weight_property,
relationship_types=relationship_types,
node_labels=node_labels,
sudo=sudo,
log_progress=log_progress,
username=username,
concurrency=concurrency,
job_id=job_id,
)

job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.cnarw", config)

return GraphSamplingResult(**JobClient.get_summary(self._arrow_client, job_id))
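Both implementations follow the same pattern: convert the keyword arguments into a GDS config, run the `v2/graph.sample.*` job to completion, and parse the job summary into a `GraphSamplingResult`. A test-style sketch of that flow with the `JobClient` calls patched out; the summary keys are assumed to match the snake_case field names of `GraphSamplingResult`, and everything below is illustrative rather than part of this change:

```python
from unittest.mock import MagicMock, patch

from graphdatascience.arrow_client.v2.job_client import JobClient
from graphdatascience.procedure_surface.arrow.graph_sampling_arrow_endpoints import (
    GraphSamplingArrowEndpoints,
)

# Fake job summary; the snake_case keys are an assumption about the summary format.
summary = {
    "graph_name": "sampled",
    "from_graph_name": "original",
    "node_count": 150,
    "relationship_count": 420,
    "start_node_count": 1,
    "project_millis": 12,
}

graph = MagicMock()
graph.name.return_value = "original"  # rwr() reads G.name() as from_graph_name

with patch.object(JobClient, "run_job_and_wait", return_value="job-1"), patch.object(
    JobClient, "get_summary", return_value=summary
):
    endpoints = GraphSamplingArrowEndpoints(arrow_client=MagicMock())
    result = endpoints.rwr(graph, "sampled", sampling_ratio=0.15)

assert result.from_graph_name == "original" and result.node_count == 150
```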