Skip to content

Commit

Permalink
stash
Browse files Browse the repository at this point in the history
  • Loading branch information
bdpedigo committed Mar 20, 2024
1 parent 7669037 commit 2ed2d09
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 12 deletions.
65 changes: 54 additions & 11 deletions networkframe/groupby.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Union
from typing import Callable, Literal, Union


class NodeGroupBy:
"""A class for grouping a `NetworkFrame` by a set of labels."""

def __init__(self, frame, source_groupby, target_groupby, induced: bool = False):
def __init__(
self, frame, source_groupby, target_groupby, by, induced: bool = False
):
"""Groupby on nodes.
Parameters
Expand All @@ -21,6 +24,7 @@ def __init__(self, frame, source_groupby, target_groupby, induced: bool = False)
self._source_groupby = source_groupby
self._target_groupby = target_groupby
self.induced = induced
self.by = by

self._axis: Union[str, int]
if source_groupby is None:
Expand Down Expand Up @@ -64,7 +68,7 @@ def __iter__(self):
for target_group, target_objects in self._target_groupby:
if self.induced and source_group != target_group:
continue
else:
else:
yield (
(source_group, target_group),
self._frame.loc[source_objects.index, target_objects.index],
Expand All @@ -76,23 +80,62 @@ def __iter__(self):
for target_group, target_objects in self._target_groupby:
yield target_group, self._frame.loc[:, target_objects.index]

# def apply(self, func, *args, data=False, **kwargs):
def apply_nodes(self, func):
pass

def apply_edges(self, func, columns=None):
by = self.by
if isinstance(by, list):
raise ValueError(
"Currently can only apply edges to a single group in `by`."
)
if self._axis != "both":
raise ValueError("Currently can only apply edges when groupby is 'both'.")

if isinstance(func, str):
if func == "size":
func = lambda x: x.shape[0]
elif func == "mean":
func = lambda x: x.mean()
elif func == "sum":
func = lambda x: x.sum()
elif func == "max":
func = lambda x: x.max()
elif func == "min":
func = lambda x: x.min()
elif func == "any":
func = lambda x: x.any()

edges = self._frame.apply_node_features(by, inplace=False).edges

edge_by = [f"source_{by}", f"target_{by}"]
if columns is not None:
out = edges.groupby(edge_by)[columns].apply(func)
else:
out = edges.groupby(edge_by).apply(func)
return out

def size_edges(self):
return self.apply_edges("size")

# def apply(self, func, to="frame"):
# """Apply a function to each group."""
# if self._axis == 'both':
# if self._axis == "both":
# answer = pd.DataFrame(
# index=self.source_group_names, columns=self.target_group_names
# )

# else:
# if self._axis == 0:
# answer = pd.Series(index=self.source_group_names)
# else:
# answer = pd.Series(index=self.target_group_names)
# for group, frame in self:
# if data:
# value = func(frame.data, *args, **kwargs)
# else:
# value = func(frame, *args, **kwargs)
# for group, frame in tqdm(self, total=len(self)):
# if to == "frame":
# value = func(frame)
# elif to == "nodes":
# value = func(frame.nodes)
# elif to == "edges":
# value = func(frame.edges)
# answer.at[group] = value
# return answer

Expand Down
29 changes: 28 additions & 1 deletion networkframe/networkframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ def groupby_nodes(
raise ValueError("Axis must be 0 or 1 or 'both'")

return NodeGroupBy(
self, source_nodes_groupby, target_nodes_groupby, induced=induced
self, source_nodes_groupby, target_nodes_groupby, induced=induced, by=by
)

@property
Expand Down Expand Up @@ -1217,6 +1217,33 @@ def k_hop_neighborhood(
select_indices
return self.query_nodes("index in @select_indices", local_dict=locals())

def condense(
self,
by: Union[Any, list],
func: Union[
Callable, Literal["mean", "sum", "max", "min", "any", "size"]
] = "size",
weight_name="weight",
columns=None,
) -> "NetworkFrame":
"""Apply a function, and create a new NetworkFrame such that the nodes of the
new frame are the groups and the edges are the result of the function.
The API and implementation of this function is rather fragile and subject to
change.
"""

edges = self.groupby_nodes(by).apply_edges(func, columns=columns)
edges.name = weight_name
edges = edges.reset_index()
edges = edges.rename(
columns={f"source_{by}": "source", f"target_{by}": "target"}
)
nodes_index = pd.Index(self.nodes[by].unique())
nodes_index.name = by
nodes = pd.DataFrame(index=nodes_index)
return self.__class__(nodes, edges, directed=self.directed)


class LocIndexer:
"""A class for indexing a NetworkFrame using .loc."""
Expand Down

0 comments on commit 2ed2d09

Please sign in to comment.