1
+ from typing import Optional , Callable , Any
2
+ from psycopg2 import connect , extensions
3
+ from multiprocessing import Process , Pipe , connection
4
+ import time
5
+ from select import select
6
+
7
+ POSTGRESQL_CHANNEL_NAME = "casbin_role_watcher"
8
+
9
+
10
+ def casbin_subscription (
11
+ process_conn : connection .PipeConnection ,
12
+ host : str ,
13
+ user : str ,
14
+ password : str ,
15
+ port : Optional [int ] = 5432 ,
16
+ delay : Optional [int ] = 2 ,
17
+ channel_name : Optional [str ] = POSTGRESQL_CHANNEL_NAME ,
18
+ ):
19
+ # delay connecting to postgresql (postgresql connection failure)
20
+ time .sleep (delay )
21
+ conn = connect (host = host , port = port , user = user , password = password )
22
+ # Can only receive notifications when not in transaction, set this for easier usage
23
+ conn .set_isolation_level (extensions .ISOLATION_LEVEL_AUTOCOMMIT )
24
+ curs = conn .cursor ()
25
+ curs .execute (f"LISTEN { channel_name } ;" )
26
+ print ("Waiting for casbin policy update" )
27
+ while True and not curs .closed :
28
+ if not select ([conn ], [], [], 5 ) == ([], [], []):
29
+ print ("Casbin policy update identified.." )
30
+ conn .poll ()
31
+ while conn .notifies :
32
+ notify = conn .notifies .pop (0 )
33
+ print (f"Notify: { notify .payload } " )
34
+ process_conn .put (notify .payload )
35
+
36
+
37
+
38
+ class PostgresqlWatcher (object ):
39
+ def __init__ (
40
+ self ,
41
+ host : str ,
42
+ user : str ,
43
+ password : str ,
44
+ port : Optional [int ] = 5432 ,
45
+ channel_name : Optional [str ] = POSTGRESQL_CHANNEL_NAME ,
46
+ start_process : Optional [bool ] = True ,
47
+ ):
48
+ self .host = host
49
+ self .port = port
50
+ self .user = user
51
+ self .password = password
52
+ self .channel_name = channel_name
53
+ self .subscribed_process , self .parent_conn = self .create_subscriber_process (
54
+ start_process
55
+ )
56
+
57
+ def create_subscriber_process (
58
+ self ,
59
+ start_process : Optional [bool ] = True ,
60
+ delay : Optional [int ] = 2 ,
61
+ ):
62
+ parent_conn , child_conn = Pipe ()
63
+ p = Process (
64
+ target = casbin_subscription ,
65
+ args = (
66
+ child_conn ,
67
+ self .host ,
68
+ self .user ,
69
+ self .password ,
70
+ self .port ,
71
+ delay ,
72
+ self .channel_name ,
73
+ ),
74
+ daemon = True ,
75
+ )
76
+ if start_process :
77
+ p .start ()
78
+ self .should_reload ()
79
+ return p , parent_conn
80
+
81
+ def update_callback (self ):
82
+ print ("callback called because casbin role updated" )
83
+
84
+ def set_update_callback (self , fn_name : Any ):
85
+ print ("runtime is set update callback" ,fn_name )
86
+ self .update_callback = fn_name
87
+
88
+
89
+ def update (self ):
90
+ conn = connect (
91
+ host = self .host ,
92
+ port = self .port ,
93
+ user = self .user ,
94
+ password = self .password ,
95
+ )
96
+ # Can only receive notifications when not in transaction, set this for easier usage
97
+ conn .set_isolation_level (extensions .ISOLATION_LEVEL_AUTOCOMMIT )
98
+ curs = conn .cursor ()
99
+ curs .execute (
100
+ f"NOTIFY { self .channel_name } ,'casbin policy update at { time .time ()} '"
101
+ )
102
+ conn .close ()
103
+ return True
104
+
105
+ def should_reload (self ):
106
+ try :
107
+ if self .parent_conn .poll ():
108
+ message = self .parent_conn .recv ()
109
+ print (f"message:{ message } " )
110
+ return True
111
+ except EOFError :
112
+ print (
113
+ "Child casbin-watcher subscribe process has stopped, "
114
+ "attempting to recreate the process in 10 seconds..."
115
+ )
116
+ self .subscribed_process , self .parent_conn = self .create_subscriber_process (
117
+ delay = 10
118
+ )
119
+ return False
0 commit comments