Skip to content

Commit d025ba8

Browse files
fix: Refactor incident processing to use IncidentBl with sessions (#4013)
1 parent d93ffcd commit d025ba8

File tree

3 files changed: +111 -109 lines changed

keep/api/bl/incidents_bl.py

+10 -2
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from typing import List, Optional
77
from uuid import UUID
88

9+
import asyncio
910
from fastapi import HTTPException
1011
from pusher import Pusher
1112
from sqlalchemy.orm.exc import StaleDataError
@@ -55,6 +56,7 @@
5556

5657

5758
class IncidentBl:
59+
5860
def __init__(
5961
self,
6062
tenant_id: str,
@@ -71,13 +73,13 @@ def __init__(
7173
self.redis = os.environ.get("REDIS", "false") == "true"
7274

7375
def create_incident(
74-
self, incident_dto: IncidentDtoIn, generated_from_ai: bool = False
76+
self, incident_dto: [IncidentDtoIn | IncidentDto], generated_from_ai: bool = False
7577
) -> IncidentDto:
7678
"""
7779
Creates a new incident.
7880
7981
Args:
80-
incident_dto (IncidentDtoIn): The data transfer object containing the details of the incident to be created.
82+
incident_dto (IncidentDtoIn | IncidentDto): The data transfer object containing the details of the incident to be created.
8183
generated_from_ai (bool, optional): Indicates if the incident was generated by Keep's AI. Defaults to False.
8284
8385
Returns:
@@ -111,6 +113,12 @@ def create_incident(
111113
)
112114
return new_incident_dto
113115

116+
def sync_add_alerts_to_incident(self, *args, **kwargs) -> None:
117+
"""
118+
Synchronous wrapper for the async add_alerts_to_incident method.
119+
"""
120+
asyncio.run(self.add_alerts_to_incident(*args, **kwargs))
121+
114122
async def add_alerts_to_incident(
115123
self,
116124
incident_id: UUID,

keep/api/core/db.py

+8 -3
Original file line number | Diff line number | Diff line change
@@ -3717,8 +3717,12 @@ def update_incident_from_dto_by_id(
37173717
return incident
37183718

37193719

3720-
def get_incident_by_fingerprint(tenant_id: str, fingerprint: str) -> Optional[Incident]:
3721-
with Session(engine) as session:
3720+
def get_incident_by_fingerprint(
3721+
tenant_id: str,
3722+
fingerprint: str,
3723+
session: Optional[Session] = None
3724+
) -> Optional[Incident]:
3725+
with existed_or_new_session(session) as session:
37223726
return session.exec(
37233727
select(Incident).where(
37243728
Incident.tenant_id == tenant_id, Incident.fingerprint == fingerprint
@@ -3729,10 +3733,11 @@ def get_incident_by_fingerprint(tenant_id: str, fingerprint: str) -> Optional[In
37293733
def delete_incident_by_id(
37303734
tenant_id: str,
37313735
incident_id: UUID,
3736+
session: Optional[Session] = None
37323737
) -> bool:
37333738
if isinstance(incident_id, str):
37343739
incident_id = __convert_to_uuid(incident_id)
3735-
with Session(engine) as session:
3740+
with existed_or_new_session(session) as session:
37363741
incident = session.exec(
37373742
select(Incident).filter(
37383743
Incident.tenant_id == tenant_id,

keep/api/tasks/process_incident_task.py

+93 -104
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,14 @@
11
import logging
22

33
from arq import Retry
4+
from sqlmodel import Session
45

6+
from keep.api.bl.incidents_bl import IncidentBl
57
from keep.api.core.db import (
6-
add_alerts_to_incident,
7-
create_incident_from_dto,
88
get_incident_by_fingerprint,
99
get_incident_by_id,
10-
update_incident_from_dto_by_id,
10+
engine,
1111
)
12-
from keep.api.core.dependencies import get_pusher_client
1312
from keep.api.models.incident import IncidentDto
1413
from keep.api.tasks.process_event_task import process_event
1514

@@ -32,121 +31,111 @@ def process_incident(
3231
"trace_id": trace_id,
3332
}
3433

35-
if ctx and isinstance(ctx, dict):
36-
extra["job_try"] = ctx.get("job_try", 0)
37-
extra["job_id"] = ctx.get("job_id", None)
34+
with Session(engine) as session:
3835

39-
if isinstance(incidents, IncidentDto):
40-
incidents = [incidents]
36+
if ctx and isinstance(ctx, dict):
37+
extra["job_try"] = ctx.get("job_try", 0)
38+
extra["job_id"] = ctx.get("job_id", None)
4139

42-
logger.info(f"Processing {len(incidents)} incidents", extra=extra)
40+
if isinstance(incidents, IncidentDto):
41+
incidents = [incidents]
4342

44-
if logger.getEffectiveLevel() == logging.DEBUG:
45-
# Lets log the incidents in debug mode
46-
extra["incident"] = [i.dict() for i in incidents]
43+
logger.info(f"Processing {len(incidents)} incidents", extra=extra)
4744

48-
try:
49-
for incident in incidents:
50-
logger.info(
51-
f"Processing incident: {incident.id}",
52-
extra={**extra, "fingerprint": incident.fingerprint},
53-
)
54-
55-
incident_from_db = get_incident_by_id(
56-
tenant_id=tenant_id, incident_id=incident.id
57-
)
45+
if logger.getEffectiveLevel() == logging.DEBUG:
46+
# Lets log the incidents in debug mode
47+
extra["incident"] = [i.dict() for i in incidents]
5848

59-
# Try to get by fingerprint if no incident was found by id
60-
if incident_from_db is None and incident.fingerprint:
61-
incident_from_db = get_incident_by_fingerprint(
62-
tenant_id=tenant_id, fingerprint=incident.fingerprint
63-
)
49+
incident_bl = IncidentBl(tenant_id, session)
6450

65-
if incident_from_db:
66-
logger.info(
67-
f"Updating incident: {incident.id}",
68-
extra={**extra, "fingerprint": incident.fingerprint},
69-
)
70-
incident_from_db = update_incident_from_dto_by_id(
71-
tenant_id=tenant_id,
72-
incident_id=incident_from_db.id,
73-
updated_incident_dto=incident,
74-
)
75-
logger.info(
76-
f"Updated incident: {incident.id}",
77-
extra={**extra, "fingerprint": incident.fingerprint},
78-
)
79-
else:
51+
try:
52+
for incident in incidents:
8053
logger.info(
81-
f"Creating incident: {incident.id}",
54+
f"Processing incident: {incident.id}",
8255
extra={**extra, "fingerprint": incident.fingerprint},
8356
)
84-
incident_from_db = create_incident_from_dto(
85-
tenant_id=tenant_id,
86-
incident_dto=incident,
87-
)
88-
logger.info(
89-
f"Created incident: {incident.id}",
90-
extra={**extra, "fingerprint": incident.fingerprint},
57+
58+
incident_from_db = get_incident_by_id(
59+
tenant_id=tenant_id, incident_id=incident.id, session=session
9160
)
9261

93-
try:
94-
if incident.alerts:
95-
logger.info("Adding incident alerts", extra=extra)
96-
processed_alerts = process_event(
97-
{},
98-
tenant_id,
99-
provider_type,
100-
provider_id,
101-
None,
102-
None,
103-
trace_id,
104-
incident.alerts,
62+
# Try to get by fingerprint if no incident was found by id
63+
if incident_from_db is None and incident.fingerprint:
64+
incident_from_db = get_incident_by_fingerprint(
65+
tenant_id=tenant_id, fingerprint=incident.fingerprint, session=session
66+
)
67+
68+
if incident_from_db:
69+
logger.info(
70+
f"Updating incident: {incident.id}",
71+
extra={**extra, "fingerprint": incident.fingerprint},
72+
)
73+
incident_from_db = incident_bl.update_incident(
74+
incident_id=incident_from_db.id,
75+
updated_incident_dto=incident,
76+
generated_by_ai=False,
77+
)
78+
logger.info(
79+
f"Updated incident: {incident.id}",
80+
extra={**extra, "fingerprint": incident.fingerprint},
81+
)
82+
else:
83+
logger.info(
84+
f"Creating incident: {incident.id}",
85+
extra={**extra, "fingerprint": incident.fingerprint},
86+
)
87+
incident_from_db = incident_bl.create_incident(
88+
incident_dto=incident,
89+
)
90+
logger.info(
91+
f"Created incident: {incident.id}",
92+
extra={**extra, "fingerprint": incident.fingerprint},
10593
)
106-
if processed_alerts:
107-
add_alerts_to_incident(
94+
95+
try:
96+
if incident.alerts:
97+
logger.info("Adding incident alerts", extra=extra)
98+
processed_alerts = process_event(
99+
{},
108100
tenant_id,
109-
incident_from_db,
110-
[
111-
processed_alert.event_id
112-
for processed_alert in processed_alerts
113-
],
114-
# Because the incident was created with the alerts count, we need to override it
115-
# otherwise it will be the sum of the previous count + the newly attached alerts count
116-
override_count=True,
117-
)
118-
logger.info("Added incident alerts", extra=extra)
119-
else:
120-
logger.info(
121-
"No alerts to add to incident, probably deduplicated",
122-
extra=extra,
101+
provider_type,
102+
provider_id,
103+
None,
104+
None,
105+
trace_id,
106+
incident.alerts,
123107
)
124-
except Exception:
125-
logger.exception("Error adding incident alerts", extra=extra)
126-
logger.info("Processed incident", extra=extra)
127-
128-
pusher_client = get_pusher_client()
129-
if not pusher_client:
130-
pass
131-
try:
132-
pusher_client.trigger(
133-
f"private-{tenant_id}",
134-
"incident-change",
135-
{},
136-
)
108+
if processed_alerts:
109+
incident_bl.sync_add_alerts_to_incident(
110+
incident_from_db.id,
111+
[
112+
processed_alert.fingerprint
113+
for processed_alert in processed_alerts
114+
],
115+
# Because the incident was created with the alerts count, we need to override it
116+
# otherwise it will be the sum of the previous count + the newly attached alerts count
117+
override_count=True,
118+
)
119+
logger.info("Added incident alerts", extra=extra)
120+
else:
121+
logger.info(
122+
"No alerts to add to incident, probably deduplicated",
123+
extra=extra,
124+
)
125+
except Exception:
126+
logger.exception("Error adding incident alerts", extra=extra)
127+
logger.info("Processed incident", extra=extra)
128+
129+
logger.info("Processed all incidents", extra=extra)
137130
except Exception:
138-
logger.exception("Failed to push incidents to the client")
139-
140-
logger.info("Processed all incidents", extra=extra)
141-
except Exception:
142-
logger.exception(
143-
"Error processing incidents",
144-
extra=extra,
145-
)
146-
147-
# Retrying only if context is present (running the job in arq worker)
148-
if bool(ctx):
149-
raise Retry(defer=ctx["job_try"] * TIMES_TO_RETRY_JOB)
131+
logger.exception(
132+
"Error processing incidents",
133+
extra=extra,
134+
)
135+
136+
# Retrying only if context is present (running the job in arq worker)
137+
if bool(ctx):
138+
raise Retry(defer=ctx["job_try"] * TIMES_TO_RETRY_JOB)
150139

151140

152141
async def async_process_incident(*args, **kwargs):

0 commit comments

Comments
 (0)