Skip to content

Commit

Permalink
Add scripts to fix association duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Sep 6, 2024
1 parent c1a9264 commit 9aa19e5
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 20 deletions.
174 changes: 161 additions & 13 deletions lib/galaxy/model/scripts/association_table_fixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,73 +15,221 @@
)

from galaxy.model import (
Base,
GroupRoleAssociation,
UserGroupAssociation,
UserRoleAssociation,
)
from galaxy.model.orm.scripts import get_config


def build_engine():
config = get_config(sys.argv, use_argparse=False, cwd=os.getcwd())
return create_engine(config["db_url"])


class AssociationNullFix(ABC):

def __init__(self, association_model: Base):
self.assoc_model = association_model
self.assoc_name = association_model.__tablename__
def __init__(self):
self.engine = build_engine()
self.assoc_model = self.association_model()
self.assoc_name = self.assoc_model.__tablename__
self.where_clause = self.build_where_clause()

def run(self):
config = get_config(sys.argv, use_argparse=False, cwd=os.getcwd())
engine = create_engine(config["db_url"])

invalid_assocs = self.count_associations_with_nulls(engine)
invalid_assocs = self.count_associations_with_nulls()
if not invalid_assocs:
print(f"Your database does not contain invalid {self.assoc_name} records")
return

print(f"Your database contains {invalid_assocs} invalid {self.assoc_name} records")
answer = input(f'Delete {invalid_assocs} invalid records? (type "yes" to confirm)\n')
if answer.lower() == "yes":
self.delete_associations_with_nulls(engine)
self.delete_associations_with_nulls()
else:
print("Operation aborted")

def count_associations_with_nulls(self, engine):
def count_associations_with_nulls(
self,
):
"""
Retrieve association records where one or both associated item ids are null.
"""
select_stmt = select(func.count()).where(self.where_clause)
with engine.connect() as conn:
with self.engine.connect() as conn:
return conn.scalar(select_stmt)

def delete_associations_with_nulls(self, engine):
def delete_associations_with_nulls(self):
"""
Delete association records where one or both associated item ids are null.
"""
delete_stmt = delete(self.assoc_model).where(self.where_clause)
with engine.begin() as conn:
with self.engine.begin() as conn:
result = conn.execute(delete_stmt)
conn.commit()
print(f"{result.rowcount} invalid records have been deleted")

@abstractmethod
def association_model(self):
"""Return model class"""

@abstractmethod
def build_where_clause(self):
"""Build where clause for filtering records containing nulls instead of associated item ids"""


class UserGroupAssociationNullFix(AssociationNullFix):

def association_model(self):
return UserGroupAssociation

def build_where_clause(self):
return or_(UserGroupAssociation.user_id == null(), UserGroupAssociation.group_id == null())


class UserRoleAssociationNullFix(AssociationNullFix):

def association_model(self):
return UserRoleAssociation

def build_where_clause(self):
return or_(UserRoleAssociation.user_id == null(), UserRoleAssociation.role_id == null())


class GroupRoleAssociationNullFix(AssociationNullFix):

def association_model(self):
return GroupRoleAssociation

def build_where_clause(self):
return or_(GroupRoleAssociation.group_id == null(), GroupRoleAssociation.role_id == null())


class AssociationDuplicateFix(ABC):

def __init__(self):
self.engine = build_engine()
self.assoc_model = self.association_model()
self.assoc_name = self.assoc_model.__tablename__

def run(self):
duplicate_assocs = self.select_duplicate_associations()
if not duplicate_assocs:
print(f"Your database does not contain duplicate {self.assoc_name} records")
return

print(f"Your database contains {len(duplicate_assocs)} groups of duplicate {self.assoc_name} records")
answer = input(
'Delete duplicate records? For each group of duplicates the oldest record will be retained (type "yes" to confirm)\n'
)
if answer.lower() == "yes":
self.delete_duplicate_associations(duplicate_assocs)
else:
print("Operation aborted")

def select_duplicate_associations(self):
"""Retrieve duplicate association records."""
select_stmt = self.build_duplicate_tuples_statement()
with self.engine.connect() as conn:
return conn.execute(select_stmt).all()

@abstractmethod
def association_model(self):
"""Return model class"""

@abstractmethod
def build_duplicate_tuples_statement(self):
"""
Build select statement returning a list of tuples (item1_id, item2_id) that have counts > 1
"""

@abstractmethod
def build_duplicate_ids_statement(self, user_id, group_id):
"""
Build select statement returning a list of ids for duplicate records retrieved via build_duplicate_tuples_statement().
"""

def delete_duplicate_associations(self, records):
"""
Delete duplicate association records retaining oldest record in each group of duplicates.
"""
to_delete = []
with self.engine.begin() as conn:
for item1_id, item2_id in records:
to_delete += self._get_duplicates_to_delete(conn, item1_id, item2_id)
for id in to_delete:
delete_stmt = delete(self.assoc_model).where(self.assoc_model.id == id)
conn.execute(delete_stmt)
conn.commit()
print(f"{len(to_delete)} duplicate records have been deleted")

def _get_duplicates_to_delete(self, connection, item1_id, item2_id):
stmt = self.build_duplicate_ids_statement(item1_id, item2_id)
duplicates = connection.scalars(stmt).all()
# IMPORTANT: we slice to skip the first item ([1:]), which is the oldest record and SHOULD NOT BE DELETED.
return duplicates[1:]


class UserGroupAssociationDuplicateFix(AssociationDuplicateFix):

def association_model(self):
return UserGroupAssociation

def build_duplicate_tuples_statement(self):
stmt = (
select(UserGroupAssociation.user_id, UserGroupAssociation.group_id)
.group_by(UserGroupAssociation.user_id, UserGroupAssociation.group_id)
.having(func.count() > 1)
)
return stmt

def build_duplicate_ids_statement(self, user_id, group_id):
stmt = (
select(UserGroupAssociation.id)
.where(UserGroupAssociation.user_id == user_id, UserGroupAssociation.group_id == group_id)
.order_by(UserGroupAssociation.update_time)
)
return stmt


class UserRoleAssociationDuplicateFix(AssociationDuplicateFix):

def association_model(self):
return UserRoleAssociation

def build_duplicate_tuples_statement(self):
stmt = (
select(UserRoleAssociation.user_id, UserRoleAssociation.role_id)
.group_by(UserRoleAssociation.user_id, UserRoleAssociation.role_id)
.having(func.count() > 1)
)
return stmt

def build_duplicate_ids_statement(self, user_id, role_id):
stmt = (
select(UserRoleAssociation.id)
.where(UserRoleAssociation.user_id == user_id, UserRoleAssociation.role_id == role_id)
.order_by(UserRoleAssociation.update_time)
)
return stmt


class GroupRoleAssociationDuplicateFix(AssociationDuplicateFix):

def association_model(self):
return GroupRoleAssociation

def build_duplicate_tuples_statement(self):
stmt = (
select(GroupRoleAssociation.group_id, GroupRoleAssociation.role_id)
.group_by(GroupRoleAssociation.group_id, GroupRoleAssociation.role_id)
.having(func.count() > 1)
)
return stmt

def build_duplicate_ids_statement(self, group_id, role_id):
stmt = (
select(GroupRoleAssociation.id)
.where(GroupRoleAssociation.group_id == group_id, GroupRoleAssociation.role_id == role_id)
.order_by(GroupRoleAssociation.update_time)
)
return stmt
Empty file modified scripts/db/fix_group_role_association_duplicates.py
100644 → 100755
Empty file.
3 changes: 1 addition & 2 deletions scripts/db/fix_group_role_association_nulls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "lib")))

from galaxy.model import GroupRoleAssociation
from galaxy.model.scripts.association_table_fixer import GroupRoleAssociationNullFix

if __name__ == "__main__":
assoc_fix = GroupRoleAssociationNullFix(GroupRoleAssociation)
assoc_fix = GroupRoleAssociationNullFix()
assoc_fix.run()
11 changes: 10 additions & 1 deletion scripts/db/fix_user_group_association_duplicates.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
# TODO
import os
import sys

sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "lib")))

from galaxy.model.scripts.association_table_fixer import UserGroupAssociationDuplicateFix

if __name__ == "__main__":
assoc_fix = UserGroupAssociationDuplicateFix()
assoc_fix.run()
3 changes: 1 addition & 2 deletions scripts/db/fix_user_group_association_nulls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "lib")))

from galaxy.model import UserGroupAssociation
from galaxy.model.scripts.association_table_fixer import UserGroupAssociationNullFix

if __name__ == "__main__":
assoc_fix = UserGroupAssociationNullFix(UserGroupAssociation)
assoc_fix = UserGroupAssociationNullFix()
assoc_fix.run()
Empty file modified scripts/db/fix_user_role_association_duplicates.py
100644 → 100755
Empty file.
3 changes: 1 addition & 2 deletions scripts/db/fix_user_role_association_nulls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "lib")))

from galaxy.model import UserRoleAssociation
from galaxy.model.scripts.association_table_fixer import UserRoleAssociationNullFix

if __name__ == "__main__":
assoc_fix = UserRoleAssociationNullFix(UserRoleAssociation)
assoc_fix = UserRoleAssociationNullFix()
assoc_fix.run()

0 comments on commit 9aa19e5

Please sign in to comment.