65 changes: 59 additions & 6 deletions timesketch/lib/llms/providers/secgemini_log_analyzer_agent.py
@@ -18,7 +18,11 @@
import asyncio
import pathlib
import tempfile
from datetime import datetime
from typing import Any, Dict, Generator, Iterable, Optional

from flask import current_app

from timesketch.lib.llms.providers import interface
from timesketch.lib.llms.providers import manager

@@ -88,6 +92,7 @@ async def _run_async_stream(self, log_path, prompt):
1. Creates a new SecGemini session.
2. Uploads the local log file to the session.
3. Streams the analysis results for the given prompt.
4. If debugging is enabled, streams the raw SecGemini responses to a debug log file.

Args:
log_path (Path): The local filesystem path to the JSONL log file.
@@ -100,6 +105,7 @@ async def _run_async_stream(self, log_path, prompt):
model=self.model, enable_logging=self.enable_logging
)
self.session_id = self._session.id
# TODO: Check whether the API key has logging enabled and raise an error if it does not.
logger.info("Started new SecGemini session: '%s'", self._session.id)
self._session.upload_and_attach_logs(
log_path, custom_fields_mapping=self.custom_fields_mapping
@@ -121,12 +127,59 @@ async def _run_async_stream(self, log_path, prompt):
"log are expected. The client automatically reconnects during "
"long-running analysis."
)
async for response in self._session.stream(prompt):
if (
response.message_type == MessageType.RESULT
and response.actor == "summarization_agent"
):
yield response.content

debug_log_file = None
if current_app.config.get("DEBUG"):
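# Build a per-session, timestamped file name in the system temp directory so concurrent analyses never collide.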
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_filename = f"secgemini_response_{timestamp}_{self.session_id}.log"
log_file_path = os.path.join(tempfile.gettempdir(), log_filename)
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
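# Open the file exclusively (O_EXCL) with owner-only permissions (0o600): an existing file is never overwritten and the raw output stays private.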
try:
debug_log_file = os.fdopen(
os.open(log_file_path, flags, 0o600), "w", encoding="utf-8"
)
logger.info(
"SecGemini raw response is being streamed to: %s", log_file_path
)
except (IOError, FileExistsError) as e:
logger.error(
"Failed to create SecGemini debug log at %s: %s",
log_file_path,
e,
exc_info=True,
)
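# Failing to create the debug log only disables the extra output; the analysis itself continues.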
debug_log_file = None

try:
async for response in self._session.stream(prompt):
if debug_log_file:
try:
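# Prefer the response's own JSON serialization when available; otherwise fall back to its string form.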
if callable(getattr(response, "to_json", None)):
json_bytes = response.to_json()
json_string = json_bytes.decode("utf-8")
debug_log_file.write(json_string + "\n")
else:
debug_log_file.write(str(response) + "\n")
debug_log_file.flush()
except IOError as e:
logger.error(
"Failed to write to SecGemini debug log: %s",
e,
exc_info=True,
)

if (
response.message_type == MessageType.RESULT
and response.actor == "summarization_agent"
):
content_chunk = response.content
yield content_chunk
finally:
if debug_log_file:
debug_log_file.close()
logger.info("Finished writing SecGemini debug log: %s", log_file_path)

def generate_stream_from_logs(
self,
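For reviewers who want to exercise the new debug path outside of Timesketch, here is a minimal standalone sketch of the secure-open pattern the diff uses. It is standard library only; the open_debug_log helper name, the example session id, and the example write are illustrative and not part of the PR, and the diff itself assumes the os module is already imported near the top of the file, outside the shown hunk.

import os
import tempfile
from datetime import datetime


def open_debug_log(session_id: str):
    """Create a private, collision-free debug log file in the temp directory."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_path = os.path.join(
        tempfile.gettempdir(), f"secgemini_response_{timestamp}_{session_id}.log"
    )
    # O_EXCL refuses to reuse an existing file; mode 0o600 keeps it owner-only.
    flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
    return os.fdopen(os.open(log_file_path, flags, 0o600), "w", encoding="utf-8")


if __name__ == "__main__":
    with open_debug_log("example-session") as debug_log:
        # The real code writes one JSON-serialized SecGemini response per line
        # and closes the file in a finally block once streaming ends.
        debug_log.write('{"actor": "summarization_agent"}\n')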