Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyTigerGraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from pyTigerGraph.pytgasync.pyTigerGraph import AsyncTigerGraphConnection
from pyTigerGraph.common.exception import TigerGraphException

__version__ = "1.9.0"
__version__ = "1.9.1"

__license__ = "Apache 2"
2 changes: 1 addition & 1 deletion pyTigerGraph/common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _prep_token_request(restppUrl: str,
else:
method = "POST"
url = gsUrl + "/gsql/v1/tokens" # used for TG 4.x
data = {"graph": graphname}
data = {"graph": graphname} if graphname else {}

# alt_url and alt_data used to construct the method and url for functions run in TG version 3.x
alt_url = restppUrl+"/requesttoken" # used for TG 3.x
Expand Down
56 changes: 25 additions & 31 deletions pyTigerGraph/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def excepthook(type, value, traceback):


class PyTigerGraphCore(object):
def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
def __init__(self, host: str = "http://127.0.0.1", graphname: str = "",
gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph",
tgCloud: bool = False, restppPort: Union[int, str] = "9000",
gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "",
Expand Down Expand Up @@ -110,7 +110,8 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
self.base64_credential = base64.b64encode(
"{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8")

self.authHeader = self._set_auth_header()
# Detect auth mode automatically by checking if jwtToken or apiToken is provided
self.authHeader, self.authMode = self._set_auth_header()

# TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version
if gsqlVersion:
Expand Down Expand Up @@ -179,7 +180,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
self.restppPort = restppPort
self.restppUrl = self.host + ":" + self.restppPort

self.gsPort = ""
self.gsPort = gsPort
if self.tgCloud and (gsPort == "14240" or gsPort == "443"):
self.gsPort = sslPort
self.gsUrl = self.host + ":" + sslPort
Expand Down Expand Up @@ -216,11 +217,11 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
def _set_auth_header(self):
"""Set the authentication header based on available tokens or credentials."""
if self.jwtToken:
return {"Authorization": "Bearer " + self.jwtToken}
return {"Authorization": "Bearer " + self.jwtToken}, "token"
elif self.apiToken:
return {"Authorization": "Bearer " + self.apiToken}
return {"Authorization": "Bearer " + self.apiToken}, "token"
else:
return {"Authorization": "Basic {0}".format(self.base64_credential)}
return {"Authorization": "Basic {0}".format(self.base64_credential)}, "pwd"

def _verify_jwt_token_support(self):
try:
Expand Down Expand Up @@ -275,39 +276,32 @@ def _error_check(self, res: dict) -> bool:
)
return False

def _prep_req(self, authMode, headers, url, method, data):
def _prep_req(self, headers, url, method, data):
logger.info("entry: _req")
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))

_headers = {}

# If JWT token is provided, always use jwtToken as token
if authMode == "token":
if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "":
token = self.jwtToken
elif isinstance(self.apiToken, tuple):
token = self.apiToken[0]
elif isinstance(self.apiToken, str) and self.apiToken.strip() != "":
token = self.apiToken
else:
token = None
if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "":
token = self.jwtToken
elif isinstance(self.apiToken, tuple):
token = self.apiToken[0]
elif isinstance(self.apiToken, str) and self.apiToken.strip() != "":
token = self.apiToken
else:
token = None

if token:
self.authHeader = {'Authorization': "Bearer " + token}
_headers = self.authHeader
else:
self.authHeader = {
'Authorization': 'Basic {0}'.format(self.base64_credential)}
_headers = self.authHeader
authMode = 'pwd'

if authMode == "pwd":
if self.jwtToken:
_headers = {'Authorization': "Bearer " + self.jwtToken}
else:
_headers = {'Authorization': 'Basic {0}'.format(
self.base64_credential)}
if token:
self.authHeader = {'Authorization': "Bearer " + token}
_headers = self.authHeader
self.authMode = "token"
else:
self.authHeader = {
'Authorization': 'Basic {0}'.format(self.base64_credential)}
_headers = self.authHeader
self.authMode = 'pwd'

if headers:
_headers.update(headers)
Expand Down
2 changes: 1 addition & 1 deletion pyTigerGraph/pyTigerGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TigerGraphConnection(pyTigerGraphVertex, pyTigerGraphEdge, pyTigerGraphUDT
pyTigerGraphLoading, pyTigerGraphPath, pyTigerGraphDataset, object):
"""Python wrapper for TigerGraph's REST++ and GSQL APIs"""

def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
def __init__(self, host: str = "http://127.0.0.1", graphname: str = "",
gsqlSecret: str = "", username: str = "tigergraph", password: str = "tigergraph",
tgCloud: bool = False, restppPort: Union[int, str] = "9000",
gsPort: Union[int, str] = "14240", gsqlVersion: str = "", version: str = "",
Expand Down
1 change: 1 addition & 0 deletions pyTigerGraph/pyTigerGraphAuth.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def getToken(self,
)
self.apiToken = token
self.authHeader = auth_header
self.authMode = "token"

logger.info("exit: getToken")
return token
Expand Down
162 changes: 22 additions & 140 deletions pyTigerGraph/pyTigerGraphBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

conn = TigerGraphConnection(
host="http://localhost",
graphname="MyGraph",
graphname="your_graph_name",
username="tigergraph",
password="tigergraph")

Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
protocol (http:// or https://). If `certPath` is `None` and the protocol is https,
a self-signed certificate will be used.
graphname:
The default graph for running queries.
The graph name for running queries. **Required** - must be specified.
gsqlSecret:
The secret key for GSQL. See https://docs.tigergraph.com/tigergraph-server/current/user-access/managing-credentials#_secrets.
username:
Expand Down Expand Up @@ -102,144 +102,17 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph",
TigerGraphException: In case on invalid URL scheme.

"""
logger.info("entry: __init__")
if logger.level == logging.DEBUG:
logger.debug("params: " + self._locals(locals()))
super().__init__(host=host, graphname=graphname, gsqlSecret=gsqlSecret,
username=username, password=password, tgCloud=tgCloud,
restppPort=restppPort, gsPort=gsPort, gsqlVersion=gsqlVersion,
version=version, apiToken=apiToken, useCert=useCert, certPath=certPath,
debug=debug, sslPort=sslPort, gcp=gcp, jwtToken=jwtToken)

inputHost = urlparse(host)
if inputHost.scheme not in ["http", "https"]:
raise TigerGraphException("Invalid URL scheme. Supported schemes are http and https.",
"E-0003")
self.netloc = inputHost.netloc
self.host = "{0}://{1}".format(inputHost.scheme, self.netloc)
if gsqlSecret != "":
self.username = "__GSQL__secret"
self.password = gsqlSecret
else:
self.username = username
self.password = password
self.graphname = graphname
self.responseConfigHeader = {}
self.awsIamHeaders = {}

self.jwtToken = jwtToken
self.apiToken = apiToken
self.base64_credential = base64.b64encode(
"{0}:{1}".format(self.username, self.password).encode("utf-8")).decode("utf-8")

self.authHeader = self._set_auth_header()

# TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version
if gsqlVersion:
self.version = gsqlVersion
elif version:
if graphname == "MyGraph":
warnings.warn(
"The `version` parameter is deprecated; use the `gsqlVersion` parameter instead.",
DeprecationWarning)
self.version = version
else:
self.version = ""

if debug is not None:
warnings.warn(
"The `debug` parameter is deprecated; configure standard logging in your app.",
DeprecationWarning)
if not debug:
sys.excepthook = excepthook # TODO Why was this necessary? Can it be removed?
sys.tracebacklimit = None

self.schema = None

# TODO Remove useCert parameter
if useCert is not None:
warnings.warn(
"The `useCert` parameter is deprecated; the need for a CA certificate is now determined by URL scheme.",
DeprecationWarning)
if inputHost.scheme == "http":
self.downloadCert = False
self.useCert = False
self.certPath = ""
elif inputHost.scheme == "https":
if not certPath:
self.downloadCert = True
else:
self.downloadCert = False
self.useCert = True
self.certPath = certPath
self.sslPort = str(sslPort)

# TODO Remove gcp parameter
if gcp:
warnings.warn("The `gcp` parameter is deprecated.",
DeprecationWarning)
self.tgCloud = tgCloud or gcp
if "tgcloud" in self.netloc.lower():
try: # If get request succeeds, using TG Cloud instance provisioned after 6/20/2022
self._get(self.host + "/api/ping", resKey="message")
self.tgCloud = True
# If get request fails, using TG Cloud instance provisioned before 6/20/2022, before new firewall config
except requests.exceptions.RequestException:
self.tgCloud = False
except TigerGraphException:
raise (TigerGraphException("Incorrect graphname."))

restppPort = str(restppPort)
gsPort = str(gsPort)
sslPort = str(sslPort)
if restppPort == gsPort:
self.restppPort = restppPort
self.restppUrl = self.host + ":" + restppPort + "/restpp"
elif (self.tgCloud and (restppPort == "9000" or restppPort == "443")):
if restppPort == gsPort:
sslPort = gsPort
self.restppPort = sslPort
self.restppUrl = self.host + ":" + sslPort + "/restpp"
else:
self.restppPort = restppPort
self.restppUrl = self.host + ":" + self.restppPort

self.gsPort = gsPort
if self.tgCloud and (gsPort == "14240" or gsPort == "443"):
self.gsPort = sslPort
self.gsUrl = self.host + ":" + sslPort
else:
self.gsPort = gsPort
self.gsUrl = self.host + ":" + self.gsPort
self.url = ""

if self.username.startswith("arn:aws:iam::"):
import boto3
from botocore.awsrequest import AWSRequest
from botocore.auth import SigV4Auth
# Prepare a GetCallerIdentity request.
request = AWSRequest(
method="POST",
url="https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15",
headers={
'Host': 'sts.amazonaws.com'
})
# Get headers
SigV4Auth(boto3.Session().get_credentials(),
"sts", "us-east-1").add_auth(request)
self.awsIamHeaders["X-Amz-Date"] = request.headers["X-Amz-Date"]
self.awsIamHeaders["X-Amz-Security-Token"] = request.headers["X-Amz-Security-Token"]
self.awsIamHeaders["Authorization"] = request.headers["Authorization"]

if self.jwtToken:
self._verify_jwt_token_support()

self.asynchronous = False

logger.info("exit: __init__")

def _set_auth_header(self):
"""Set the authentication header based on available tokens or credentials."""
if self.jwtToken:
return {"Authorization": "Bearer " + self.jwtToken}
elif self.apiToken:
return {"Authorization": "Bearer " + self.apiToken}
else:
return {"Authorization": "Basic {0}".format(self.base64_credential)}
"The default graphname 'MyGraph' is deprecated. Please explicitly specify your graph name.",
DeprecationWarning
)

def _verify_jwt_token_support(self):
try:
Expand Down Expand Up @@ -305,8 +178,8 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N
Returns:
The (relevant part of the) response from the request (as a dictionary).
"""
_headers, _data, verify = self._prep_req(
authMode, headers, url, method, data)
# Deprecated: authMode
_headers, _data, verify = self._prep_req(headers, url, method, data)

if "GSQL-TIMEOUT" in _headers:
http_timeout = (10, int(int(_headers["GSQL-TIMEOUT"])/1000) + 10)
Expand Down Expand Up @@ -361,6 +234,7 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N
self.restppUrl = newRestppUrl
self.restppPort = self.gsPort
else:
e.add_note(f"headers: {_headers}")
raise e

return self._parse_req(res, jsonResponse, strictJson, skipCheck, resKey)
Expand Down Expand Up @@ -569,3 +443,11 @@ def _version_greater_than_4_0(self) -> bool:
if version[0] >= "4" and version[1] > "0":
return True
return False

def _validate_graphname(self, operation_name=""):
"""Validate that graphname is set for operations that require it."""
if not self.graphname:
raise TigerGraphException(
f"Graph name is required for {operation_name}. Please specify graphname when creating the connection.",
"E-0004"
)
Loading