Skip to content

Commit

Permalink
Wrap all sqlalchemy usages of session into a block (#1706)
Browse files Browse the repository at this point in the history
  • Loading branch information
sausage-todd authored Oct 17, 2023
1 parent 99d9c81 commit 6845ad5
Showing 1 changed file with 72 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def __init__(self, tenant_id="", db_url=False, test=False, send=True):
)

Base.metadata.create_all(self.engine, checkfirst=True)
Session = sessionmaker(bind=self.engine)
self.session = Session()
self.Session = sessionmaker(bind=self.engine)

self.tenant_id = tenant_id
self.send = send
Expand Down Expand Up @@ -106,25 +105,25 @@ def find_in_table(self, table, query, many=False):
dict: document
"""

search_query = self.session.query(table)
for attr, value in query.items():

# Check if query is nested
nested_count = attr.count(".")
# If nested
if nested_count > 0:
attributes = attr.split(".")
nested_attributes = tuple(attributes[1:])
# Define nested expression
expr = getattr(table, attributes[0])[nested_attributes]
# Execute search_query
search_query = search_query.filter(expr == json.dumps(value))
else:
search_query = search_query.filter(getattr(table, attr) == value)
with self.Session() as session:
search_query = session.query(table)
for attr, value in query.items():
# Check if query is nested
nested_count = attr.count(".")
# If nested
if nested_count > 0:
attributes = attr.split(".")
nested_attributes = tuple(attributes[1:])
# Define nested expression
expr = getattr(table, attributes[0])[nested_attributes]
# Execute search_query
search_query = search_query.filter(expr == json.dumps(value))
else:
search_query = search_query.filter(getattr(table, attr) == value)

if many:
return search_query.all()
return search_query.first()
if many:
return search_query.all()
return search_query.first()

def find_by_id(self, table, id):
"""
Expand All @@ -138,15 +137,17 @@ def find_by_id(self, table, id):
dict: the document
"""

return self.session.query(table).get(id)
with self.Session() as session:
return session.query(table).get(id)

def find_all_usernames(self):
with self.engine.connect() as con:
return con.execute(
f"""select m."id", mw."username", m."displayName", m."emails"
from "members" m
inner join "memberActivityAggregatesMVs" mw on m.id = mw.id
where m."tenantId" = '{self.tenant_id}'""").fetchall()
where m."tenantId" = '{self.tenant_id}'"""
).fetchall()

def find_all(
self, table, ignore_tenant: "bool" = False, query: "dict" = None, order: "dict" = None
Expand All @@ -173,29 +174,30 @@ def find_all(
**{dbk.TENANT: uuid.UUID(self.tenant_id)},
}

search_query = self.session.query(table)
for attr, value in query.items():
# Check if query is nested
nested_count = attr.count(".")
# If nested
if nested_count > 0:
attributes = attr.split(".")
nested_attributes = tuple(attributes[1:])
# Define nested expression
expr = getattr(table, attributes[0])[nested_attributes]
# Execute search_query
search_query = search_query.filter(expr == json.dumps(value))
else:
search_query = search_query.filter(getattr(table, attr) == value)

if order:
for key, value in order.items():
if value:
search_query = search_query.order_by(asc(key))
with self.Session() as session:
search_query = session.query(table)
for attr, value in query.items():
# Check if query is nested
nested_count = attr.count(".")
# If nested
if nested_count > 0:
attributes = attr.split(".")
nested_attributes = tuple(attributes[1:])
# Define nested expression
expr = getattr(table, attributes[0])[nested_attributes]
# Execute search_query
search_query = search_query.filter(expr == json.dumps(value))
else:
search_query = search_query.order_by(desc(key))
search_query = search_query.filter(getattr(table, attr) == value)

return search_query.all()
if order:
for key, value in order.items():
if value:
search_query = search_query.order_by(asc(key))
else:
search_query = search_query.order_by(desc(key))

return search_query.all()

def find_activities(self, search_filters=None):
if not search_filters:
Expand All @@ -208,22 +210,23 @@ def count(self, table, search_filters=None):

search_filters[dbk.TENANT] = uuid.UUID(self.tenant_id)

search_query = self.session.query(table)
for attr, value in search_filters.items():
# Check if query is nested
nested_count = attr.count(".")
# If nested
if nested_count > 0:
attributes = attr.split(".")
nested_attributes = tuple(attributes[1:])
# Define nested expression
expr = getattr(table, attributes[0])[nested_attributes]
# Execute query
search_query = search_query.filter(expr == json.dumps(value))
else:
search_query = search_query.filter(getattr(table, attr) == value)
with self.Session() as session:
search_query = session.query(table)
for attr, value in search_filters.items():
# Check if query is nested
nested_count = attr.count(".")
# If nested
if nested_count > 0:
attributes = attr.split(".")
nested_attributes = tuple(attributes[1:])
# Define nested expression
expr = getattr(table, attributes[0])[nested_attributes]
# Execute query
search_query = search_query.filter(expr == json.dumps(value))
else:
search_query = search_query.filter(getattr(table, attr) == value)

return search_query.count()
return search_query.count()

def find_available_microservices(self, service):
"""
Expand Down Expand Up @@ -253,16 +256,17 @@ def find_new_members(self, microservice, query: "dict" = None) -> "list[dict]":
**{dbk.TENANT: uuid.UUID(self.tenant_id)},
}

search_query = self.session.query(Member)
with self.Session() as session:
search_query = session.query(Member)

# Filter with query
for attr, value in query.items():
search_query = search_query.filter(getattr(Member, attr) == value)
# Filter with query
for attr, value in query.items():
search_query = search_query.filter(getattr(Member, attr) == value)

# Find members that are new
# We use a security padding of 5 minutes
search_query = search_query.filter(
Member.createdAt >= (microservice.updatedAt - timedelta(minutes=5))
).order_by(Member.createdAt.desc())
# Find members that are new
# We use a security padding of 5 minutes
search_query = search_query.filter(
Member.createdAt >= (microservice.updatedAt - timedelta(minutes=5))
).order_by(Member.createdAt.desc())

return search_query.all()
return search_query.all()

0 comments on commit 6845ad5

Please sign in to comment.