44import tomllib
55import psycopg2
66
7-
8- def psql (identity : str , sql : str ) -> str :
9- """Call `psql` and execute the given SQL statement."""
10- result = subprocess .run (
11- ["psql" , "-h" , "127.0.0.1" , "-p" , "5432" , "-U" , "postgres" , "-d" , "quickstart" , "--quiet" , "-c" , sql ],
12- encoding = "utf8" ,
13- env = {** os .environ , "PGPASSWORD" : identity },
14- stdout = subprocess .PIPE ,
15- stderr = subprocess .PIPE ,
16- text = True ,
17- )
18-
19- if result .stderr :
20- raise Exception (result .stderr .strip ())
21- return result .stdout .strip ()
22-
23-
24- def connect_db (identity : str ):
25- """Connect to the database using `psycopg2`."""
26- conn = psycopg2 .connect (host = "127.0.0.1" , port = 5432 , user = "postgres" , password = identity , dbname = "quickstart" )
27- conn .set_session (autocommit = True ) # Disable automic transaction
28- return conn
29-
30-
317class SqlFormat (Smoketest ):
328 AUTOPUBLISH = False
339 MODULE_CODE = """
@@ -168,9 +144,32 @@ class SqlFormat(Smoketest):
168144}
169145"""
170146
147+ def psql (self , identity : str , sql : str ) -> str :
148+ """Call `psql` and execute the given SQL statement."""
149+ server = self .get_server_address ()
150+ result = subprocess .run (
151+ ["psql" , "-h" , server ["host" ], "-p" , "5432" , "-U" , "postgres" , "-d" , "quickstart" , "--quiet" , "-c" , sql ],
152+ encoding = "utf8" ,
153+ env = {** os .environ , "PGPASSWORD" : identity },
154+ stdout = subprocess .PIPE ,
155+ stderr = subprocess .PIPE ,
156+ text = True ,
157+ )
158+
159+ if result .stderr :
160+ raise Exception (result .stderr .strip ())
161+ return result .stdout .strip ()
162+
163+ def connect_db (self , identity : str ):
164+ """Connect to the database using `psycopg2`."""
165+ server = self .get_server_address ()
166+ conn = psycopg2 .connect (host = server ["host" ], port = 5432 , user = "postgres" , password = identity , dbname = "quickstart" )
167+ conn .set_session (autocommit = True ) # Disable automic transaction
168+ return conn
169+
171170 def assertSql (self , token : str , sql : str , expected ):
172171 self .maxDiff = None
173- sql_out = psql (token , sql )
172+ sql_out = self . psql (token , sql )
174173 sql_out = "\n " .join ([line .rstrip () for line in sql_out .splitlines ()])
175174 expected = "\n " .join ([line .rstrip () for line in expected .splitlines ()])
176175 print (sql_out )
@@ -242,7 +241,7 @@ def test_sql_conn(self):
242241 self .publish_module ("quickstart" , clear = True )
243242 self .call ("test" )
244243
245- conn = connect_db (token )
244+ conn = self . connect_db (token )
246245 # Check prepared statements (faked by `psycopg2`)
247246 with conn .cursor () as cur :
248247 cur .execute ("select * from t_uints where u8 = %s and u16 = %s" , (105 , 1050 ))
@@ -262,20 +261,20 @@ def test_failures(self):
262261 self .publish_module ("quickstart" , clear = True )
263262
264263 # Empty query
265- sql_out = psql (token , "" )
264+ sql_out = self . psql (token , "" )
266265 self .assertEqual (sql_out , "" )
267266
268267 # Connection fails with invalid token
269268 with self .assertRaises (Exception ) as cm :
270- psql ("invalid_token" , "SELECT * FROM t_uints" )
269+ self . psql ("invalid_token" , "SELECT * FROM t_uints" )
271270 self .assertIn ("Invalid token" , str (cm .exception ))
272271
273272 # Returns error for unsupported `sql` statements
274273 with self .assertRaises (Exception ) as cm :
275- psql (token , "SELECT CASE a WHEN 1 THEN 'one' ELSE 'other' END FROM t_uints" )
274+ self . psql (token , "SELECT CASE a WHEN 1 THEN 'one' ELSE 'other' END FROM t_uints" )
276275 self .assertIn ("Unsupported" , str (cm .exception ))
277276
278277 # And prepared statements
279278 with self .assertRaises (Exception ) as cm :
280- psql (token , "SELECT * FROM t_uints where u8 = $1" )
279+ self . psql (token , "SELECT * FROM t_uints where u8 = $1" )
281280 self .assertIn ("Unsupported" , str (cm .exception ))
0 commit comments