Skip to content

Commit af95bb2

Browse files
authored
[DPE-7302] Prefixes helpers (#27)
* Prefixes helpers * Unit tests * Pass the in result directly
1 parent 6748b7c commit af95bb2

File tree

4 files changed

+187
-10
lines changed

4 files changed

+187
-10
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
[project]
55
name = "postgresql-charms-single-kernel"
66
description = "Shared and reusable code for PostgreSQL-related charms"
7-
version = "16.1.0"
7+
version = "16.1.1"
88
readme = "README.md"
99
license = "Apache-2.0"
1010
authors = [

single_kernel_postgresql/utils/postgresql.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def __init__(self, message: Optional[str] = None):
114114
self.message = message
115115

116116

117+
class PostgreSQLUpdateUserError(PostgreSQLBaseError):
118+
"""Exception raised when creating a user fails."""
119+
120+
117121
class PostgreSQLUndefinedHostError(PostgreSQLBaseError):
118122
"""Exception when host is not set."""
119123

@@ -146,6 +150,10 @@ class PostgreSQLGetPostgreSQLVersionError(PostgreSQLBaseError):
146150
"""Exception raised when retrieving PostgreSQL version fails."""
147151

148152

153+
class PostgreSQLListDatabasesError(PostgreSQLBaseError):
154+
"""Exception raised when retrieving the databases."""
155+
156+
149157
class PostgreSQLListAccessibleDatabasesForUserError(PostgreSQLBaseError):
150158
"""Exception raised when retrieving the accessible databases for a user fails."""
151159

@@ -439,24 +447,36 @@ def _adjust_user_definition(
439447
Returns:
440448
A tuple containing the adjusted user definition and a list of additional statements.
441449
"""
450+
db_roles, connect_statements = self._adjust_user_roles(user, roles, database)
451+
if db_roles:
452+
str_roles = [f'"{role}"' for role in db_roles]
453+
user_definition += f" IN ROLE {', '.join(str_roles)}"
454+
return user_definition, connect_statements
455+
456+
def _adjust_user_roles(
457+
self, user: str, roles: Optional[List[str]], database: Optional[str]
458+
) -> Tuple[List[str], List[str]]:
459+
"""Adjusts the user definition to include additional statements.
460+
461+
Returns:
462+
A tuple containing the adjusted user definition and a list of additional statements.
463+
"""
464+
db_roles = []
442465
connect_statements = []
443466
if database:
444467
if roles is not None and not any(
445-
True
446-
for role in roles
447-
if role in [ROLE_STATS, ROLE_READ, ROLE_DML, ROLE_BACKUP, ROLE_DBA]
468+
role in [ROLE_STATS, ROLE_READ, ROLE_DML, ROLE_BACKUP, ROLE_DBA] for role in roles
448469
):
449-
user_definition += f' IN ROLE "charmed_{database}_admin", "charmed_{database}_dml"'
470+
db_roles.append(f"charmed_{database}_admin")
471+
db_roles.append(f"charmed_{database}_dml")
450472
else:
451473
connect_statements.append(
452474
SQL("GRANT CONNECT ON DATABASE {} TO {};").format(
453475
Identifier(database), Identifier(user)
454476
)
455477
)
456478
if roles is not None and any(
457-
True
458-
for role in roles
459-
if role
479+
role
460480
in [
461481
ROLE_STATS,
462482
ROLE_READ,
@@ -466,14 +486,15 @@ def _adjust_user_definition(
466486
ROLE_ADMIN,
467487
ROLE_DATABASES_OWNER,
468488
]
489+
for role in roles
469490
):
470491
for system_database in ["postgres", "template1"]:
471492
connect_statements.append(
472493
SQL("GRANT CONNECT ON DATABASE {} TO {};").format(
473494
Identifier(system_database), Identifier(user)
474495
)
475496
)
476-
return user_definition, connect_statements
497+
return db_roles, connect_statements
477498

478499
def _process_extra_user_roles(
479500
self, user: str, extra_user_roles: Optional[List[str]] = None
@@ -1841,3 +1862,50 @@ def drop_hba_triggers(self) -> None:
18411862
finally:
18421863
if connection:
18431864
connection.close()
1865+
1866+
def list_databases(self, prefix: Optional[str] = None) -> List[str]:
1867+
"""List non-system databases starting with prefix."""
1868+
prefix_stmt = (
1869+
SQL(" AND datname LIKE {}").format(Literal(prefix + "%")) if prefix else SQL("")
1870+
)
1871+
try:
1872+
with self._connect_to_database() as connection, connection.cursor() as cursor:
1873+
cursor.execute(
1874+
SQL(
1875+
"SELECT datname FROM pg_database WHERE datistemplate = false AND datname <>'postgres'{};"
1876+
).format(prefix_stmt)
1877+
)
1878+
return [row[0] for row in cursor.fetchall()]
1879+
except psycopg2.Error as e:
1880+
raise PostgreSQLListDatabasesError() from e
1881+
finally:
1882+
if connection:
1883+
connection.close()
1884+
1885+
def add_user_to_databases(
1886+
self, user: str, databases: List[str], extra_user_roles: Optional[List[str]] = None
1887+
) -> None:
1888+
"""Grant user access to database."""
1889+
try:
1890+
roles, _ = self._process_extra_user_roles(user, extra_user_roles)
1891+
connect_stmt = []
1892+
for database in databases:
1893+
db_roles, db_connect_stmt = self._adjust_user_roles(user, roles, database)
1894+
roles += db_roles
1895+
connect_stmt += db_connect_stmt
1896+
with self._connect_to_database() as connection, connection.cursor() as cursor:
1897+
cursor.execute(SQL("RESET ROLE;"))
1898+
cursor.execute(SQL("BEGIN;"))
1899+
cursor.execute(SQL("SET LOCAL log_statement = 'none';"))
1900+
cursor.execute(SQL("COMMIT;"))
1901+
1902+
# Add extra user roles to the new user.
1903+
for role in roles:
1904+
cursor.execute(
1905+
SQL("GRANT {} TO {};").format(Identifier(role), Identifier(user))
1906+
)
1907+
for statement in connect_stmt:
1908+
cursor.execute(statement)
1909+
except psycopg2.Error as e:
1910+
logger.error(f"Failed to create user: {e}")
1911+
raise PostgreSQLUpdateUserError() from e

tests/unit/test_postgresql.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
PostgreSQLCreateUserError,
2525
PostgreSQLDatabasesSetupError,
2626
PostgreSQLGetLastArchivedWALError,
27+
PostgreSQLListDatabasesError,
2728
PostgreSQLUndefinedHostError,
2829
PostgreSQLUndefinedPasswordError,
30+
PostgreSQLUpdateUserError,
2931
ROLE_DATABASES_OWNER,
3032
)
3133
from single_kernel_postgresql.config.literals import Substrates
@@ -813,3 +815,110 @@ def test_set_up_database_k8s_skips_change_owner_and_chmod(harness):
813815
# On K8S substrate we must not attempt to change ownership or chmod the path.
814816
_change_owner.assert_not_called()
815817
_chmod.assert_not_called()
818+
819+
820+
def test_list_databases():
821+
with patch(
822+
"single_kernel_postgresql.utils.postgresql.PostgreSQL._connect_to_database",
823+
) as _connect_to_database:
824+
pg = PostgreSQL(
825+
Substrates.VM, "primary", "current", "operator", "password", "postgres", None
826+
)
827+
execute = _connect_to_database.return_value.__enter__.return_value.cursor.return_value.__enter__.return_value.execute
828+
829+
# No prefix
830+
pg.list_databases()
831+
execute.assert_called_once_with(
832+
Composed([
833+
SQL(
834+
"SELECT datname FROM pg_database WHERE datistemplate = false AND datname <>'postgres'"
835+
),
836+
SQL(""),
837+
SQL(";"),
838+
])
839+
)
840+
execute.reset_mock()
841+
842+
# With prefix
843+
pg.list_databases(prefix="test")
844+
execute.assert_called_once_with(
845+
Composed([
846+
SQL(
847+
"SELECT datname FROM pg_database WHERE datistemplate = false AND datname <>'postgres'"
848+
),
849+
Composed([SQL(" AND datname LIKE "), Literal("test%")]),
850+
SQL(";"),
851+
])
852+
)
853+
execute.reset_mock()
854+
855+
# Exception
856+
execute.side_effect = psycopg2.Error
857+
with pytest.raises(PostgreSQLListDatabasesError):
858+
pg.list_databases()
859+
assert False
860+
861+
862+
def test_add_user_to_databases():
863+
with (
864+
patch(
865+
"single_kernel_postgresql.utils.postgresql.PostgreSQL._connect_to_database"
866+
) as _connect_to_database,
867+
patch(
868+
"single_kernel_postgresql.utils.postgresql.PostgreSQL._process_extra_user_roles",
869+
return_value=([], []),
870+
),
871+
):
872+
pg = PostgreSQL(
873+
Substrates.VM, "primary", "current", "operator", "password", "postgres", None
874+
)
875+
execute = _connect_to_database.return_value.__enter__.return_value.cursor.return_value.__enter__.return_value.execute
876+
877+
pg.add_user_to_databases("test-user", ["db1", "db2"])
878+
assert execute.call_count == 8
879+
execute.assert_any_call(SQL("RESET ROLE;"))
880+
execute.assert_any_call(SQL("BEGIN;"))
881+
execute.assert_any_call(SQL("SET LOCAL log_statement = 'none';"))
882+
execute.assert_any_call(SQL("COMMIT;"))
883+
execute.assert_any_call(
884+
Composed([
885+
SQL("GRANT "),
886+
Identifier("charmed_db1_admin"),
887+
SQL(" TO "),
888+
Identifier("test-user"),
889+
SQL(";"),
890+
])
891+
)
892+
execute.assert_any_call(
893+
Composed([
894+
SQL("GRANT "),
895+
Identifier("charmed_db1_dml"),
896+
SQL(" TO "),
897+
Identifier("test-user"),
898+
SQL(";"),
899+
])
900+
)
901+
execute.assert_any_call(
902+
Composed([
903+
SQL("GRANT "),
904+
Identifier("charmed_db2_admin"),
905+
SQL(" TO "),
906+
Identifier("test-user"),
907+
SQL(";"),
908+
])
909+
)
910+
execute.assert_any_call(
911+
Composed([
912+
SQL("GRANT "),
913+
Identifier("charmed_db2_dml"),
914+
SQL(" TO "),
915+
Identifier("test-user"),
916+
SQL(";"),
917+
])
918+
)
919+
920+
# Exception
921+
execute.side_effect = psycopg2.Error
922+
with pytest.raises(PostgreSQLUpdateUserError):
923+
pg.add_user_to_databases("test-user", ["db1", "db2"])
924+
assert False

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)