@@ -15,15 +15,18 @@ class SqlServerContainer(DbContainer):
1515 >>> import sqlalchemy
1616 >>> from testcontainers.mssql import SqlServerContainer
1717
18- >>> with SqlServerContainer() as mssql:
19- ... engine = sqlalchemy.create_engine(mssql.get_connection_url())
20- ... with engine.begin() as connection:
21- ... result = connection.execute(sqlalchemy.text("select @@VERSION"))
18+ >>> with SqlServerContainer() as mssql:
19+ ... engine = sqlalchemy.create_engine(mssql.get_connection_url())
20+ ... result = engine.execute(sqlalchemy.text("select @@VERSION"))
21+ Notes
22+ -----
23+ Requires `ODBC Driver 17 for SQL Server <https://docs.microsoft.com/en-us/sql/connect/odbc/
24+ linux-mac/installing-the-microsoft-odbc-driver-for-sql-server>`_.
2225 """
2326
2427 def __init__ (self , image : str = "mcr.microsoft.com/mssql/server:2019-latest" ,
2528 username : str = "SA" , password : Optional [str ] = None , port : int = 1433 ,
26- dbname : str = "tempdb" , dialect : str = 'mssql+pymssql' , ** kwargs ) -> None :
29+ dbname : str = "tempdb" , dialect : str = 'mssql+pymssql' , driver : str = "ODBC Driver 17 for SQL Server" , ** kwargs ) -> None :
2730 raise_for_deprecated_parameter (kwargs , "user" , "username" )
2831 super (SqlServerContainer , self ).__init__ (image , ** kwargs )
2932
@@ -34,6 +37,7 @@ def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest",
3437 self .username = username
3538 self .dbname = dbname
3639 self .dialect = dialect
40+ self .driver = driver
3741
3842 def _configure (self ) -> None :
3943 self .with_env ("SA_PASSWORD" , self .password )
@@ -42,7 +46,10 @@ def _configure(self) -> None:
4246 self .with_env ("ACCEPT_EULA" , 'Y' )
4347
4448 def get_connection_url (self ) -> str :
45- return super ()._create_connection_url (
49+ base_url = super (SqlServerContainer , self )._create_connection_url (
4650 dialect = self .dialect , username = self .username , password = self .password ,
47- dbname = self .dbname , port = self .port
51+ db_name = self .dbname , port = self .port
4852 )
53+ url = base_url + f"?driver={ '+' .join (self .driver .split (' ' ))} "
54+ return url
55+
0 commit comments