diff --git a/pyTigerGraph/__init__.py b/pyTigerGraph/__init__.py index 0a46f960..f916bb34 100644 --- a/pyTigerGraph/__init__.py +++ b/pyTigerGraph/__init__.py @@ -2,6 +2,6 @@ from pyTigerGraph.pytgasync.pyTigerGraph import AsyncTigerGraphConnection from pyTigerGraph.common.exception import TigerGraphException -__version__ = "1.9.0" +__version__ = "1.9.1" __license__ = "Apache 2" diff --git a/pyTigerGraph/common/auth.py b/pyTigerGraph/common/auth.py index fca1d50c..0b0079ef 100644 --- a/pyTigerGraph/common/auth.py +++ b/pyTigerGraph/common/auth.py @@ -83,7 +83,7 @@ def _prep_token_request(restppUrl: str, else: method = "POST" url = gsUrl + "/gsql/v1/tokens" # used for TG 4.x - data = {"graph": graphname} + data = {"graph": graphname} if graphname else {} # alt_url and alt_data used to construct the method and url for functions run in TG version 3.x alt_url = restppUrl+"/requesttoken" # used for TG 3.x diff --git a/pyTigerGraph/common/base.py b/pyTigerGraph/common/base.py index f90383c4..1dd95578 100644 --- a/pyTigerGraph/common/base.py +++ b/pyTigerGraph/common/base.py @@ -33,7 +33,7 @@ def excepthook(type, value, traceback): class PyTigerGraphCore(object): - def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", + def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", tgCloud: bool = False, restppPort: Union[int, str] = "9000", gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", @@ -110,7 +110,8 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", self.base64_credential = base64.b64encode( "{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8") - self.authHeader = self._set_auth_header() + # Detect auth mode automatically by checking if jwtToken or apiToken is provided + self.authHeader, self.authMode = self._set_auth_header() # TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version if gsqlVersion: @@ -179,7 +180,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", self.restppPort = restppPort self.restppUrl = self.host + ":" + self.restppPort - self.gsPort = "" + self.gsPort = gsPort if self.tgCloud and (gsPort == "14240" or gsPort == "443"): self.gsPort = sslPort self.gsUrl = self.host + ":" + sslPort @@ -216,11 +217,11 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", def _set_auth_header(self): """Set the authentication header based on available tokens or credentials.""" if self.jwtToken: - return {"Authorization": "Bearer " + self.jwtToken} + return {"Authorization": "Bearer " + self.jwtToken}, "token" elif self.apiToken: - return {"Authorization": "Bearer " + self.apiToken} + return {"Authorization": "Bearer " + self.apiToken}, "token" else: - return {"Authorization": "Basic {0}".format(self.base64_credential)} + return {"Authorization": "Basic {0}".format(self.base64_credential)}, "pwd" def _verify_jwt_token_support(self): try: @@ -275,7 +276,7 @@ def _error_check(self, res: dict) -> bool: ) return False - def _prep_req(self, authMode, headers, url, method, data): + def _prep_req(self, headers, url, method, data): logger.info("entry: _req") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) @@ -283,31 +284,24 @@ def _prep_req(self, authMode, headers, url, method, data): _headers = {} # If JWT token is provided, always use jwtToken as token - if authMode == "token": - if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "": - token = self.jwtToken - elif isinstance(self.apiToken, tuple): - token = self.apiToken[0] - elif isinstance(self.apiToken, str) and self.apiToken.strip() != "": - token = self.apiToken - else: - token = None + if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "": + token = self.jwtToken + elif isinstance(self.apiToken, tuple): + token = self.apiToken[0] + elif isinstance(self.apiToken, str) and self.apiToken.strip() != "": + token = self.apiToken + else: + token = None - if token: - self.authHeader = {'Authorization': "Bearer " + token} - _headers = self.authHeader - else: - self.authHeader = { - 'Authorization': 'Basic {0}'.format(self.base64_credential)} - _headers = self.authHeader - authMode = 'pwd' - - if authMode == "pwd": - if self.jwtToken: - _headers = {'Authorization': "Bearer " + self.jwtToken} - else: - _headers = {'Authorization': 'Basic {0}'.format( - self.base64_credential)} + if token: + self.authHeader = {'Authorization': "Bearer " + token} + _headers = self.authHeader + self.authMode = "token" + else: + self.authHeader = { + 'Authorization': 'Basic {0}'.format(self.base64_credential)} + _headers = self.authHeader + self.authMode = 'pwd' if headers: _headers.update(headers) diff --git a/pyTigerGraph/pyTigerGraph.py b/pyTigerGraph/pyTigerGraph.py index c6583be7..0da90d9d 100644 --- a/pyTigerGraph/pyTigerGraph.py +++ b/pyTigerGraph/pyTigerGraph.py @@ -26,7 +26,7 @@ class TigerGraphConnection(pyTigerGraphVertex, pyTigerGraphEdge, pyTigerGraphUDT pyTigerGraphLoading, pyTigerGraphPath, pyTigerGraphDataset, object): """Python wrapper for TigerGraph's REST++ and GSQL APIs""" - def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", + def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", tgCloud: bool = False, restppPort: Union[int, str] = "9000", gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", diff --git a/pyTigerGraph/pyTigerGraphAuth.py b/pyTigerGraph/pyTigerGraphAuth.py index 3396bec6..fa896f01 100644 --- a/pyTigerGraph/pyTigerGraphAuth.py +++ b/pyTigerGraph/pyTigerGraphAuth.py @@ -239,6 +239,7 @@ def getToken(self, ) self.apiToken = token self.authHeader = auth_header + self.authMode = "token" logger.info("exit: getToken") return token diff --git a/pyTigerGraph/pyTigerGraphBase.py b/pyTigerGraph/pyTigerGraphBase.py index 624c082f..7bcbe3a6 100644 --- a/pyTigerGraph/pyTigerGraphBase.py +++ b/pyTigerGraph/pyTigerGraphBase.py @@ -13,7 +13,7 @@ conn = TigerGraphConnection( host="http://localhost", - graphname="MyGraph", + graphname="your_graph_name", username="tigergraph", password="tigergraph") @@ -64,7 +64,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", protocol (http:// or https://). If `certPath` is `None` and the protocol is https, a self-signed certificate will be used. graphname: - The default graph for running queries. + The graph name for running queries. **Required** - must be specified. gsqlSecret: The secret key for GSQL. See https://docs.tigergraph.com/tigergraph-server/current/user-access/managing-credentials#_secrets. username: @@ -102,144 +102,17 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", TigerGraphException: In case on invalid URL scheme. """ - logger.info("entry: __init__") - if logger.level == logging.DEBUG: - logger.debug("params: " + self._locals(locals())) + super().__init__(host=host, graphname=graphname, gsqlSecret=gsqlSecret, + username=username, password=password, tgCloud=tgCloud, + restppPort=restppPort, gsPort=gsPort, gsqlVersion=gsqlVersion, + version=version, apiToken=apiToken, useCert=useCert, certPath=certPath, + debug=debug, sslPort=sslPort, gcp=gcp, jwtToken=jwtToken) - inputHost = urlparse(host) - if inputHost.scheme not in ["http", "https"]: - raise TigerGraphException("Invalid URL scheme. Supported schemes are http and https.", - "E-0003") - self.netloc = inputHost.netloc - self.host = "{0}://{1}".format(inputHost.scheme, self.netloc) - if gsqlSecret != "": - self.username = "__GSQL__secret" - self.password = gsqlSecret - else: - self.username = username - self.password = password - self.graphname = graphname - self.responseConfigHeader = {} - self.awsIamHeaders = {} - - self.jwtToken = jwtToken - self.apiToken = apiToken - self.base64_credential = base64.b64encode( - "{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8") - - self.authHeader = self._set_auth_header() - - # TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version - if gsqlVersion: - self.version = gsqlVersion - elif version: + if graphname == "MyGraph": warnings.warn( - "The `version` parameter is deprecated; use the `gsqlVersion` parameter instead.", - DeprecationWarning) - self.version = version - else: - self.version = "" - - if debug is not None: - warnings.warn( - "The `debug` parameter is deprecated; configure standard logging in your app.", - DeprecationWarning) - if not debug: - sys.excepthook = excepthook # TODO Why was this necessary? Can it be removed? - sys.tracebacklimit = None - - self.schema = None - - # TODO Remove useCert parameter - if useCert is not None: - warnings.warn( - "The `useCert` parameter is deprecated; the need for a CA certificate is now determined by URL scheme.", - DeprecationWarning) - if inputHost.scheme == "http": - self.downloadCert = False - self.useCert = False - self.certPath = "" - elif inputHost.scheme == "https": - if not certPath: - self.downloadCert = True - else: - self.downloadCert = False - self.useCert = True - self.certPath = certPath - self.sslPort = str(sslPort) - - # TODO Remove gcp parameter - if gcp: - warnings.warn("The `gcp` parameter is deprecated.", - DeprecationWarning) - self.tgCloud = tgCloud or gcp - if "tgcloud" in self.netloc.lower(): - try: # If get request succeeds, using TG Cloud instance provisioned after 6/20/2022 - self._get(self.host + "/api/ping", resKey="message") - self.tgCloud = True - # If get request fails, using TG Cloud instance provisioned before 6/20/2022, before new firewall config - except requests.exceptions.RequestException: - self.tgCloud = False - except TigerGraphException: - raise (TigerGraphException("Incorrect graphname.")) - - restppPort = str(restppPort) - gsPort = str(gsPort) - sslPort = str(sslPort) - if restppPort == gsPort: - self.restppPort = restppPort - self.restppUrl = self.host + ":" + restppPort + "/restpp" - elif (self.tgCloud and (restppPort == "9000" or restppPort == "443")): - if restppPort == gsPort: - sslPort = gsPort - self.restppPort = sslPort - self.restppUrl = self.host + ":" + sslPort + "/restpp" - else: - self.restppPort = restppPort - self.restppUrl = self.host + ":" + self.restppPort - - self.gsPort = gsPort - if self.tgCloud and (gsPort == "14240" or gsPort == "443"): - self.gsPort = sslPort - self.gsUrl = self.host + ":" + sslPort - else: - self.gsPort = gsPort - self.gsUrl = self.host + ":" + self.gsPort - self.url = "" - - if self.username.startswith("arn:aws:iam::"): - import boto3 - from botocore.awsrequest import AWSRequest - from botocore.auth import SigV4Auth - # Prepare a GetCallerIdentity request. - request = AWSRequest( - method="POST", - url="https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", - headers={ - 'Host': 'sts.amazonaws.com' - }) - # Get headers - SigV4Auth(boto3.Session().get_credentials(), - "sts", "us-east-1").add_auth(request) - self.awsIamHeaders["X-Amz-Date"] = request.headers["X-Amz-Date"] - self.awsIamHeaders["X-Amz-Security-Token"] = request.headers["X-Amz-Security-Token"] - self.awsIamHeaders["Authorization"] = request.headers["Authorization"] - - if self.jwtToken: - self._verify_jwt_token_support() - - self.asynchronous = False - - logger.info("exit: __init__") - - def _set_auth_header(self): - """Set the authentication header based on available tokens or credentials.""" - if self.jwtToken: - return {"Authorization": "Bearer " + self.jwtToken} - elif self.apiToken: - return {"Authorization": "Bearer " + self.apiToken} - else: - return {"Authorization": "Basic {0}".format(self.base64_credential)} + "The default graphname 'MyGraph' is deprecated. Please explicitly specify your graph name.", + DeprecationWarning + ) def _verify_jwt_token_support(self): try: @@ -305,8 +178,8 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N Returns: The (relevant part of the) response from the request (as a dictionary). """ - _headers, _data, verify = self._prep_req( - authMode, headers, url, method, data) + # Deprecated: authMode + _headers, _data, verify = self._prep_req(headers, url, method, data) if "GSQL-TIMEOUT" in _headers: http_timeout = (10, int(int(_headers["GSQL-TIMEOUT"])/1000) + 10) @@ -361,6 +234,7 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N self.restppUrl = newRestppUrl self.restppPort = self.gsPort else: + e.add_note(f"headers: {_headers}") raise e return self._parse_req(res, jsonResponse, strictJson, skipCheck, resKey) @@ -569,3 +443,11 @@ def _version_greater_than_4_0(self) -> bool: if version[0] >= "4" and version[1] > "0": return True return False + + def _validate_graphname(self, operation_name=""): + """Validate that graphname is set for operations that require it.""" + if not self.graphname: + raise TigerGraphException( + f"Graph name is required for {operation_name}. Please specify graphname when creating the connection.", + "E-0004" + ) diff --git a/pyTigerGraph/pyTigerGraphSchema.py b/pyTigerGraph/pyTigerGraphSchema.py index 1c6e52a0..d55496da 100644 --- a/pyTigerGraph/pyTigerGraphSchema.py +++ b/pyTigerGraph/pyTigerGraphSchema.py @@ -13,6 +13,7 @@ _prep_upsert_data, _prep_get_endpoints ) +from pyTigerGraph.common.exception import TigerGraphException from pyTigerGraph.pyTigerGraphBase import pyTigerGraphBase logger = logging.getLogger(__name__) @@ -84,6 +85,49 @@ def getSchema(self, udts: bool = True, force: bool = False) -> dict: return self.schema + def getSchemaVer(self) -> int: + """Retrieves the schema version of the graph by running an interpreted query. + + Returns: + The schema version as an integer. + + Endpoint: + - `POST /gsqlserver/interpreted_query` (In TigerGraph versions 3.x) + - `POST /gsql/v1/queries/interpret` (In TigerGraph versions 4.x) + """ + logger.info("entry: getSchemaVer") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + # Create the interpreted query to get schema version + query_text = f'INTERPRET QUERY () FOR GRAPH {self.graphname} {{ PRINT "OK"; }}' + + try: + # Run the interpreted query + if self._version_greater_than_4_0(): + ret = self._req("POST", self.gsUrl + "/gsql/v1/queries/interpret", + params={}, data=query_text, authMode="pwd", resKey="version", + headers={'Content-Type': 'text/plain'}) + else: + ret = self._req("POST", self.gsUrl + "/gsqlserver/interpreted_query", data=query_text, + params={}, authMode="pwd", resKey="version") + + schema_version_int = None + if isinstance(ret, dict) and "schema" in ret: + schema_version = ret["schema"] + try: + schema_version_int = int(schema_version) + except (ValueError, TypeError): + logger.warning(f"Schema version '{schema_version}' could not be converted to integer") + if schema_version_int is None: + logger.warning("Schema version not found in query result") + logger.info("exit: _get_schema_ver") + return schema_version_int + + except Exception as e: + logger.error(f"Error getting schema version: {str(e)}") + raise Exception(f"Failed to get schema version: {str(e)}") + def upsertData(self, data: Union[str, object], atomic: bool = False, ackAll: bool = False, newVertexOnly: bool = False, vertexMustExist: bool = False, updateVertexOnly: bool = False) -> dict: @@ -130,7 +174,7 @@ def upsertData(self, data: Union[str, object], atomic: bool = False, ackAll: boo if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) - logger.info("exit: getSchema") + logger.info("exit: upsertData") return res diff --git a/pyTigerGraph/pytgasync/pyTigerGraph.py b/pyTigerGraph/pytgasync/pyTigerGraph.py index 357e775c..1f5a7cd2 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraph.py +++ b/pyTigerGraph/pytgasync/pyTigerGraph.py @@ -26,7 +26,7 @@ class AsyncTigerGraphConnection(AsyncPyTigerGraphVertex, AsyncPyTigerGraphEdge, AsyncPyTigerGraphLoading, AsyncPyTigerGraphPath, AsyncPyTigerGraphDataset, object): """Python wrapper for TigerGraph's REST++ and GSQL APIs""" - def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", + def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", tgCloud: bool = False, restppPort: Union[int, str] = "9000", gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", diff --git a/pyTigerGraph/pytgasync/pyTigerGraphAuth.py b/pyTigerGraph/pytgasync/pyTigerGraphAuth.py index 4849b1d7..948fd5ab 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphAuth.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphAuth.py @@ -192,6 +192,7 @@ async def getToken(self, secret: str = None, setToken: bool = True, lifetime: in self.apiToken = token self.authHeader = auth_header + self.authMode = "token" logger.info("exit: getToken") return token diff --git a/pyTigerGraph/pytgasync/pyTigerGraphBase.py b/pyTigerGraph/pytgasync/pyTigerGraphBase.py index 3da08e2d..9edaf5de 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphBase.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphBase.py @@ -16,7 +16,7 @@ conn = AsyncTigerGraphConnection( host="http://localhost", - graphname="MyGraph", + graphname="", username="tigergraph", password="tigergraph") @@ -39,7 +39,7 @@ class AsyncPyTigerGraphBase(PyTigerGraphCore): - def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", + def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph", tgCloud: bool = False, restppPort: Union[int, str] = "9000", gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "", @@ -130,8 +130,8 @@ async def _req(self, method: str, url: str, authMode: str = "token", headers: di Returns: The (relevant part of the) response from the request (as a dictionary). """ - _headers, _data, verify = self._prep_req( - authMode, headers, url, method, data) + # Deprecated: authMode + _headers, _data, verify = self._prep_req(headers, url, method, data) if "GSQL-TIMEOUT" in _headers: http_timeout = (10, int(int(_headers["GSQL-TIMEOUT"])/1000) + 10) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py index ac97d7e1..50d18c3e 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py @@ -4,7 +4,7 @@ All functions in this module are called as methods on a link:https://docs.tigergraph.com/pytigergraph/current/core-functions/base[`TigerGraphConnection` object]. """ import logging -import time +import asyncio from typing import TYPE_CHECKING, Union, Optional diff --git a/pyTigerGraph/pytgasync/pyTigerGraphSchema.py b/pyTigerGraph/pytgasync/pyTigerGraphSchema.py index 35f2f1f6..4966ac99 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphSchema.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphSchema.py @@ -14,6 +14,7 @@ _prep_upsert_data, _prep_get_endpoints ) +from pyTigerGraph.common.exception import TigerGraphException logger = logging.getLogger(__name__) @@ -84,6 +85,49 @@ async def getSchema(self, udts: bool = True, force: bool = False) -> dict: return self.schema + async def getSchemaVer(self) -> int: + """Retrieves the schema version of the graph by running an interpreted query. + + Returns: + The schema version as an integer. + + Endpoint: + - `POST /gsqlserver/interpreted_query` (In TigerGraph versions 3.x) + - `POST /gsql/v1/queries/interpret` (In TigerGraph versions 4.x) + """ + logger.info("entry: getSchemaVer") + if logger.level == logging.DEBUG: + logger.debug("params: " + self._locals(locals())) + + # Create the interpreted query to get schema version + query_text = f'INTERPRET QUERY () FOR GRAPH {self.graphname} {{ PRINT "OK"; }}' + + try: + # Run the interpreted query + if await self._version_greater_than_4_0(): + ret = await self._req("POST", self.gsUrl + "/gsql/v1/queries/interpret", + params={}, data=query_text, authMode="pwd", resKey="version", + headers={'Content-Type': 'text/plain'}) + else: + ret = await self._req("POST", self.gsUrl + "/gsqlserver/interpreted_query", data=query_text, + params={}, authMode="pwd", resKey="version") + + schema_version_int = None + if isinstance(ret, dict) and "schema" in ret: + schema_version = ret["schema"] + try: + schema_version_int = int(schema_version) + except (ValueError, TypeError): + logger.warning(f"Schema version '{schema_version}' could not be converted to integer") + if schema_version_int is None: + logger.warning("Schema version not found in query result") + logger.info("exit: _get_schema_ver") + return schema_version_int + + except Exception as e: + logger.error(f"Error getting schema version: {str(e)}") + raise Exception(f"Failed to get schema version: {str(e)}") + async def upsertData(self, data: Union[str, object], atomic: bool = False, ackAll: bool = False, newVertexOnly: bool = False, vertexMustExist: bool = False, updateVertexOnly: bool = False) -> dict: @@ -131,7 +175,7 @@ async def upsertData(self, data: Union[str, object], atomic: bool = False, ackAl if logger.level == logging.DEBUG: logger.debug("return: " + str(res)) - logger.info("exit: getSchema") + logger.info("exit: upsertData") return res