Skip to content

Commit

Permalink
new return scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
bdpedigo committed Feb 1, 2024
1 parent 494bb23 commit ebd30d7
Showing 1 changed file with 61 additions and 30 deletions.
91 changes: 61 additions & 30 deletions networkframe/networkframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
sources: Optional[pd.Index] = None,
targets: Optional[pd.Index] = None,
validate: bool = True,
induced: bool = True,
):
"""
Parameters
Expand All @@ -56,6 +57,10 @@ def __init__(
validate
Whether to check that the nodes and edges are valid. This can be turned off
to speed performance but risks causing errors later on.
induced
Whether the network is induced, i.e. whether the nodes and edges are
specified as a subgraph of a larger network. Currently non-functional,
subject to some API changes in the future.
"""

# TODO more checks ensuring that nodes and edges are valid.
Expand Down Expand Up @@ -87,20 +92,21 @@ def __init__(
# TODO some checks on repeated edges if not directed
self.directed = directed

def copy(self) -> "NetworkFrame":
"""Return a copy of the NetworkFrame.
Returns
-------
:
A copy of the NetworkFrame.
"""
return copy.deepcopy(self)
def _return(self, inplace: bool = False, **kwargs):
if inplace:
for k, v in kwargs.items():
setattr(self, k, v)
return None
else:
out = copy.copy(self)
for k, v in kwargs.items():
setattr(out, k, v)
return out

def _return(
def _old_return(
self, nodes: pd.DataFrame, edges: pd.DataFrame, inplace: bool
) -> Optional[Self]:
"""Return a view of the NetworkFrame.
"""Return a view/shallow copy of the NetworkFrame.
This is used internally to return a view of the NetworkFrame rather than
a copy.
Expand All @@ -122,7 +128,15 @@ def _return(
self.edges = edges
return None
else:
return self.__class__(nodes, edges, directed=self.directed, validate=False)
# kwargs = self.get_public_attributes()
# kwargs["nodes"] = nodes
# kwargs["edges"] = edges
# kwargs["validate"] = False
# return self.__class__(**kwargs)
out = copy.copy(self)
out.nodes = nodes
out.edges = edges
return out

@property
def sources(self) -> pd.Index:
Expand Down Expand Up @@ -194,7 +208,7 @@ def reindex_nodes(self, index: pd.Index) -> Self:
"""
nodes = self.nodes.reindex(index=index, axis=0)
edges = self.edges.query("(source in @nodes.index) & (target in @nodes.index)")
out: NetworkFrame = self._return(nodes, edges, inplace=False) # type: ignore
out: NetworkFrame = self._return(nodes=nodes, edges=edges, inplace=False) # type: ignore
return out # type: ignore

def remove_nodes(
Expand All @@ -218,7 +232,7 @@ def remove_nodes(
nodes = self.nodes.drop(index=nodes)
# get the edges that are connected to the nodes that are left after the query
edges = self.edges.query("(source in @nodes.index) & (target in @nodes.index)")
return self._return(nodes, edges, inplace=inplace)
return self._return(nodes=nodes, edges=edges, inplace=inplace)

def remove_edges(self, remove_edges: pd.DataFrame, inplace=False) -> Optional[Self]:
# """Remove edges from the network."""
Expand All @@ -233,19 +247,19 @@ def remove_edges(self, remove_edges: pd.DataFrame, inplace=False) -> Optional[Se
# TODO i think this destroys the old index?
edges = self.edges.set_index(["source", "target"]).loc[new_index].reset_index()

return self._return(self.nodes, edges, inplace=inplace)
return self._return(edges=edges, inplace=inplace)

def add_nodes(self, new_nodes: pd.DataFrame, inplace=False) -> Optional[Self]:
# """Add nodes to the network."""
nodes = pd.concat([self.nodes, new_nodes], copy=False, sort=False, axis=0)

return self._return(nodes, self.edges, inplace=inplace)
return self._return(nodes=nodes, inplace=inplace)

def add_edges(self, new_edges: pd.DataFrame, inplace=False) -> Optional[Self]:
# """Add edges to the network."""
edges = pd.concat([self.edges, new_edges], copy=False, sort=False, axis=0)

return self._return(self.nodes, edges, inplace=inplace)
return self._return(edges=edges, inplace=inplace)

def query_nodes(
self,
Expand Down Expand Up @@ -312,7 +326,14 @@ def query_nodes(
"(source in @nodes.index) & (target in @nodes.index)", **kwargs
)

return self._return(nodes, edges, inplace=inplace)
return self._return(nodes=nodes, edges=edges, inplace=inplace)

def get_public_attributes(self):
return {
key: value
for key, value in self.__dict__.items()
if not key.startswith("_")
}

def query_edges(
self,
Expand Down Expand Up @@ -376,7 +397,7 @@ def query_edges(
expr, local_dict=local_dict, global_dict=global_dict, **kwargs
)

return self._return(self.nodes, edges, inplace=inplace)
return self._return(edges=edges, inplace=inplace)

def remove_unused_nodes(self, inplace: bool = False) -> Optional[Self]:
"""
Expand Down Expand Up @@ -423,7 +444,7 @@ def remove_unused_nodes(self, inplace: bool = False) -> Optional[Self]:
new_index = source_index.union(target_index)
nodes = self.nodes.loc[new_index]

return self._return(nodes, self.edges, inplace=inplace)
return self._return(nodes=nodes, inplace=inplace)

def apply_node_features(
self, columns: ColumnsType, axis: EdgeAxisType = "both", inplace: bool = False
Expand Down Expand Up @@ -465,7 +486,7 @@ def apply_node_features(
if axis in ["target", "both"]:
for col in columns:
edges[f"target_{col}"] = self.edges["target"].map(self.nodes[col])
return self._return(self.nodes, edges, inplace=inplace)
return self._return(edges=edges, inplace=inplace)

def to_adjacency(
self, weight_col: str = "weight", aggfunc: Union[str, Callable] = "sum"
Expand Down Expand Up @@ -674,7 +695,7 @@ def largest_connected_component(
nodes = self.nodes.iloc[mask]
edges = self.edges.query("(source in @nodes.index) & (target in @nodes.index)")

return self._return(nodes, edges, inplace=inplace)
return self._return(nodes=nodes, edges=edges, inplace=inplace)

def connected_components(self) -> Iterator["NetworkFrame"]:
"""
Expand Down Expand Up @@ -748,7 +769,20 @@ def label_nodes_by_component(
nodes[name] = labels
nodes[name] = nodes[name].astype(int)

return self._return(nodes, self.edges, inplace=inplace)
return self._return(nodes=nodes, inplace=inplace)

def component_labels(self) -> pd.Series:
"""Return the indices of the connected components.
Returns
-------
:
A series of the same length as the number of nodes, where each element
corresponds to the connected component of the node at that index.
"""
_, labels = self._get_component_indices()
labels = pd.Series(labels, index=self.nodes.index)
return labels

def select_component_from_node(
self, node_id: Any, directed=True, inplace=False
Expand Down Expand Up @@ -784,13 +818,10 @@ def select_component_from_node(

dists = shortest_path(sparse_adjacency, directed=directed, indices=node_iloc)
mask = ~np.isinf(dists)
out = self.loc[mask, mask]
if inplace:
self.nodes = out.nodes
self.edges = out.edges
return None
else:
return out
subindex = self.nodes.index[mask]
return self.query_nodes(
"index in @subindex", inplace=inplace, local_dict=locals()
)

def groupby_nodes(
self, by: Union[Any, list], axis: EdgeAxisType = "both", **kwargs
Expand Down

0 comments on commit ebd30d7

Please sign in to comment.