diff --git a/examples/mssqlclient.py b/examples/mssqlclient.py index cca3e9b59..ccb7a789d 100755 --- a/examples/mssqlclient.py +++ b/examples/mssqlclient.py @@ -87,29 +87,51 @@ if options.aesKey is not None: options.k = True - ms_sql = tds.MSSQL(options.target_ip, int(options.port), remoteName) - ms_sql.connect() - try: - if options.k is True: - res = ms_sql.kerberosLogin(options.db, username, password, domain, options.hashes, options.aesKey, - kdcHost=options.dc_ip) + with tds.MSSQL( + address=options.target_ip, + port=int(options.port), + remoteName=remoteName + ) as mssql_instance: + + try: + if options.k is True: + res = mssql_instance.kerberosLogin( + database=options.db, + username=username, + password=password, + domain=domain, + hashes=options.hashes, + aesKey=options.aesKey, + kdcHost=options.dc_ip, + TGT=None, + TGS=None, + useCache=True + ) + else: + res = mssql_instance.login( + database=options.db, + username=username, + password=password, + domain=domain, + hashes=options.hashes, + useWindowsAuth=options.windows_auth + ) + except Exception as exc: + logging.debug("Exception:", exc_info=True) + logging.error(str(exc)) + res = False else: - res = ms_sql.login(options.db, username, password, domain, options.hashes, options.windows_auth) - ms_sql.printReplies() - except Exception as e: - logging.debug("Exception:", exc_info=True) - logging.error(str(e)) - res = False - if res is True: - shell = SQLSHELL(ms_sql, options.show) - if options.file: - for line in options.file.readlines(): - print("SQL> %s" % line, end=' ') - shell.onecmd(line) - elif options.command: - for c in options.command: - print("SQL> %s" % c) - shell.onecmd(c) - else: - shell.cmdloop() - ms_sql.disconnect() + mssql_instance.printReplies() + + if res is True: + shell = SQLSHELL(mssql_instance, options.show) + if options.file: + for line in options.file.readlines(): + print("SQL> %s" % line, end=' ') + shell.onecmd(line) + elif options.command: + for c in options.command: + print("SQL> %s" % c) + shell.onecmd(c) + else: + shell.cmdloop() diff --git a/impacket/tds.py b/impacket/tds.py index da5e8180a..154f297d1 100644 --- a/impacket/tds.py +++ b/impacket/tds.py @@ -1691,6 +1691,9 @@ def batchStatement(self, cmd,tuplemode=False): sql_query = batch def changeDB(self, db): + + + if db != self.currentDB: chdb = 'use %s' % db self.batch(chdb) @@ -1713,4 +1716,19 @@ def RunSQLStatement(self,db,sql_query,wait=True,**kwArgs): self.RunSQLQuery(db,sql_query,wait=wait) if self.lastError: raise self.lastError - return True \ No newline at end of file + return True + + def __enter__(self): + """ + Enter the runtime context related to this object. + Establishes the connection. + Returns the MSSQL instance itself. + """ + self.connect() + return self + + def __exit__(self, exc_type, exc_value, traceback): + """ + Exit the runtime context and ensure the connection is closed. + """ + self.disconnect() \ No newline at end of file