Skip to content

Commit 8382db4

Browse files
authored
fix: fixed should_reload behaviour, close PostgreSQL connections, block until PostgresqlWatcher is ready, refactorings (#29)
* chore: updated dev requirements * chore: format code with black * chore: updated .gitignore * fix: type hint, multiprocessing.Pipe is a Callable and not a type * fix: make Watcher.should_reload return value consistent * fix: Handle Connection and Process objects consistenly and close them before creating new ones * feat: Customize the postgres channel name * chore: Some code reorg - Make PostgresqlWatcher.create_subscription_process a private method - Rename casbin_subscription to _casbin_channel_subscription * docs: added doc string for PostgresqlWatcher.update * refactor: PostgresqlWatcher.set_update_callback * refactor!: Rename 'start_process' flag to 'start_listening' * docs: Added doc string to PostgresqlWatcher.__init__ * fix: Added proper destructor for PostgresqlWatcher * chore: fix type hints and proper handling of the channel_name argument and its default value * test: fix tests decrease select timeout to one second in child Process remove infinite timout in PostgresqlWatcher.should_reload create a new watcher instance for every test case * feat: Setup logging module for unit tests * fix: typo * feat: channel subscription with proper resource cleanup Moved channel subscription function to separate file and added context manager for the connection, that handles SIGINT, SIGTERM for proper resource cleanup * chore: removed unnecessary tests * feat: Wait for Process to be ready to receive messages from PostgreSQL * test: multiple instances of the watcher * test: make sure every test case uses its own channel * test: no update * refactor: moved code into with block * feat: automaticall call the update handler if it is provided * refactor: sorted imports * docs: updated README * refactor: improved readibility * refactor: resolve a potential infinite loop with a custom Exception * refactor: make timeout configurable by the user * fix: docs * fix: ensure type hint compatibility with Python 3.9 * feat: make sure multiple calls of update() get resolved by one call of should_reload() thanks to @pradeepranwa1
1 parent 4b808d0 commit 8382db4

8 files changed

+361
-118
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,4 @@ dmypy.json
130130

131131
.idea/
132132
*.iml
133+
.vscode

README.md

+36-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pip install casbin-postgresql-watcher
1616
```
1717

1818
## Basic Usage Example
19-
### With Flask-authz
19+
2020
```python
2121
from flask_authz import CasbinEnforcer
2222
from postgresql_watcher import PostgresqlWatcher
@@ -25,23 +25,51 @@ from casbin.persist.adapters import FileAdapter
2525

2626
casbin_enforcer = CasbinEnforcer(app, adapter)
2727
watcher = PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME)
28-
watcher.set_update_callback(casbin_enforcer.e.load_policy)
28+
watcher.set_update_callback(casbin_enforcer.load_policy)
2929
casbin_enforcer.set_watcher(watcher)
30-
```
3130

32-
## Basic Usage Example With SSL Enabled
31+
# Call should_reload before every call of enforce to make sure
32+
# the policy is update to date
33+
watcher.should_reload()
34+
if casbin_enforcer.enforce("alice", "data1", "read"):
35+
# permit alice to read data1
36+
pass
37+
else:
38+
# deny the request, show an error
39+
pass
40+
```
3341

34-
See [PostgresQL documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS) for full details of SSL parameters.
42+
alternatively, if you need more control
3543

36-
### With Flask-authz
3744
```python
3845
from flask_authz import CasbinEnforcer
3946
from postgresql_watcher import PostgresqlWatcher
4047
from flask import Flask
4148
from casbin.persist.adapters import FileAdapter
4249

4350
casbin_enforcer = CasbinEnforcer(app, adapter)
44-
watcher = PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME, sslmode="verify_full", sslcert=SSLCERT, sslrootcert=SSLROOTCERT, sslkey=SSLKEY)
45-
watcher.set_update_callback(casbin_enforcer.e.load_policy)
51+
watcher = PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME)
4652
casbin_enforcer.set_watcher(watcher)
53+
54+
# Call should_reload before every call of enforce to make sure
55+
# the policy is update to date
56+
if watcher.should_reload():
57+
casbin_enforcer.load_policy()
58+
59+
if casbin_enforcer.enforce("alice", "data1", "read"):
60+
# permit alice to read data1
61+
pass
62+
else:
63+
# deny the request, show an error
64+
pass
65+
```
66+
67+
## Basic Usage Example With SSL Enabled
68+
69+
See [PostgresQL documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS) for full details of SSL parameters.
70+
71+
```python
72+
...
73+
watcher = PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME, sslmode="verify_full", sslcert=SSLCERT, sslrootcert=SSLROOTCERT, sslkey=SSLKEY)
74+
...
4775
```

dev_requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
black==20.8b1
1+
black==24.4.2

postgresql_watcher/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .watcher import PostgresqlWatcher
1+
from .watcher import PostgresqlWatcher, PostgresqlWatcherChannelSubscriptionTimeoutError
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from enum import IntEnum
2+
from logging import Logger
3+
from multiprocessing.connection import Connection
4+
from select import select
5+
from signal import signal, SIGINT, SIGTERM
6+
from time import sleep
7+
from typing import Optional
8+
9+
from psycopg2 import connect, extensions, InterfaceError
10+
11+
12+
CASBIN_CHANNEL_SELECT_TIMEOUT = 1 # seconds
13+
14+
15+
def casbin_channel_subscription(
16+
process_conn: Connection,
17+
logger: Logger,
18+
host: str,
19+
user: str,
20+
password: str,
21+
channel_name: str,
22+
port: int = 5432,
23+
dbname: str = "postgres",
24+
delay: int = 2,
25+
sslmode: Optional[str] = None,
26+
sslrootcert: Optional[str] = None,
27+
sslcert: Optional[str] = None,
28+
sslkey: Optional[str] = None,
29+
):
30+
# delay connecting to postgresql (postgresql connection failure)
31+
sleep(delay)
32+
db_connection = connect(
33+
host=host,
34+
port=port,
35+
user=user,
36+
password=password,
37+
dbname=dbname,
38+
sslmode=sslmode,
39+
sslrootcert=sslrootcert,
40+
sslcert=sslcert,
41+
sslkey=sslkey,
42+
)
43+
# Can only receive notifications when not in transaction, set this for easier usage
44+
db_connection.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT)
45+
db_cursor = db_connection.cursor()
46+
context_manager = _ConnectionManager(db_connection, db_cursor)
47+
48+
with context_manager:
49+
db_cursor.execute(f"LISTEN {channel_name};")
50+
logger.debug("Waiting for casbin policy update")
51+
process_conn.send(_ChannelSubscriptionMessage.IS_READY)
52+
53+
while not db_cursor.closed:
54+
try:
55+
select_result = select(
56+
[db_connection],
57+
[],
58+
[],
59+
CASBIN_CHANNEL_SELECT_TIMEOUT,
60+
)
61+
if select_result != ([], [], []):
62+
logger.debug("Casbin policy update identified")
63+
db_connection.poll()
64+
while db_connection.notifies:
65+
notify = db_connection.notifies.pop(0)
66+
logger.debug(f"Notify: {notify.payload}")
67+
process_conn.send(_ChannelSubscriptionMessage.RECEIVED_UPDATE)
68+
except (InterfaceError, OSError) as e:
69+
# Log an exception if these errors occurred without the context beeing closed
70+
if not context_manager.connections_were_closed:
71+
logger.critical(e, exc_info=True)
72+
break
73+
74+
75+
class _ChannelSubscriptionMessage(IntEnum):
76+
IS_READY = 1
77+
RECEIVED_UPDATE = 2
78+
79+
80+
class _ConnectionManager:
81+
"""
82+
You can not use 'with' and a connection / cursor directly in this setup.
83+
For more details see this issue: https://github.com/psycopg/psycopg2/issues/941#issuecomment-864025101.
84+
As a workaround this connection manager / context manager class is used, that also handles SIGINT and SIGTERM and
85+
closes the database connection.
86+
"""
87+
88+
def __init__(self, connection, cursor) -> None:
89+
self.connection = connection
90+
self.cursor = cursor
91+
self.connections_were_closed = False
92+
93+
def __enter__(self):
94+
signal(SIGINT, self._close_connections)
95+
signal(SIGTERM, self._close_connections)
96+
return self
97+
98+
def _close_connections(self, *_):
99+
if self.cursor is not None:
100+
self.cursor.close()
101+
self.cursor = None
102+
if self.connection is not None:
103+
self.connection.close()
104+
self.connection = None
105+
self.connections_were_closed = True
106+
107+
def __exit__(self, *_):
108+
self._close_connections()

0 commit comments

Comments
 (0)