Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions graph_split/split_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,38 +356,44 @@ def split_cv(df, split_type, n_folds, seed=None):
verify_split(df, train_idx[i], test_idx[i], split_type)
return train_idx, test_idx

def generate_negative_samples(edges, random_samples=None, duplicates=True, seed=None):
    '''
    Create negative edges for a set of positive edges.

    For any positive edge **(a, b)**, create a negative edge **(a, c)** such that
    (a, c) was not present in the set of positive edges. For every distinct
    source ``a`` the function draws as many negative targets as ``a`` has
    positive edges (fewer when ``duplicates`` is False and not enough candidate
    targets exist).

    Parameters:
        edges: Pandas DataFrame containing exactly two columns where each row is
            the edge (a, b). The column names are irrelevant; the first column
            is treated as the source and the second as the target. The frame is
            not modified.
        random_samples: Pandas Series (or any 1-D array-like) containing the
            possible labels for c. None if c should be sampled from the set of b.
        duplicates: True if the edge (a, c) should be able to appear more than
            once (sampling with replacement). False to sample without
            replacement.
        seed: Random generation seed for reproducibility of samples.

    :return: DataFrame with columns ``source`` and ``target`` containing the
        negative edges (a, c). Sources for which every candidate label is
        already a positive target contribute no rows.
    :raises ValueError: if ``edges`` does not have exactly two columns.
    '''
    n_columns = edges.shape[1]
    if n_columns > 2:
        raise ValueError('Too many columns! Ensure edges only contain columns for an edge (a, b) and no other information.')
    if n_columns < 2:
        raise ValueError('Too few columns! Ensure edges contain one column for each endpoint of an edge (a, b).')

    rng = np.random.RandomState(seed)

    # Work on a renamed copy so the caller's frame keeps its own column names.
    # Assigning ``columns`` (instead of ``set_axis(..., inplace=False)``) works
    # across pandas versions; the ``inplace`` argument was removed in pandas 2.0.
    df = edges.copy()
    df.columns = ['source', 'target']

    labels = df['target'] if random_samples is None else pd.Series(random_samples)
    sample_space = set(labels.unique())

    all_sampled_sources = []
    all_sampled_targets = []

    # groupby iterates the sources in sorted order, which together with the
    # sorted candidate list keeps the output reproducible for a fixed seed.
    for source, positives in df.groupby('source')['target']:
        candidates = sorted(sample_space.difference(positives))
        if not candidates:
            # Every possible label is already a positive target of this source;
            # RandomState.choice would raise on an empty population.
            continue

        if duplicates:
            sample_count = len(positives)
        else:
            # Without replacement we cannot draw more than the candidates available.
            sample_count = min(len(positives), len(candidates))

        sampled = rng.choice(candidates, size=sample_count, replace=bool(duplicates))
        all_sampled_targets.extend(sampled)
        all_sampled_sources.extend([source] * sample_count)

    return pd.DataFrame({'source': all_sampled_sources, 'target': all_sampled_targets})



Expand Down
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@ build-backend = "setuptools.build_meta"

[project]
name = "graph_split"
version = "0.3.1"
version = "0.3.2"
description = "A package to split edges of graphs using different criteria compatible with machine learning model training."
readme = "README.md"
license = "GPL-3.0-only"
authors = [
{ name = "Nure Tasnina", email = "[email protected]" }
]
requires-python = ">=3.9"
dependencies = [
"pandas",
"numpy",
"scikit-learn"
]
requires-python = ">=3.7"

[project.urls]
Homepage = "https://github.com/Murali-group/graph-split"