-
Notifications
You must be signed in to change notification settings - Fork 722
Expand file tree
/
Copy pathmemory_engine.py
More file actions
7863 lines (6864 loc) · 336 KB
/
memory_engine.py
File metadata and controls
7863 lines (6864 loc) · 336 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Memory Engine for Memory Banks.
This implements a sophisticated memory architecture that combines:
1. Temporal links: Memories connected by time proximity
2. Semantic links: Memories connected by meaning/similarity
3. Entity links: Memories connected by shared entities (PERSON, ORG, etc.)
4. Spreading activation: Search through the graph with activation decay
5. Dynamic weighting: Recency and frequency-based importance
"""
import asyncio
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any
import asyncpg
import httpx
import tiktoken
from ..config import get_config
from ..metrics import get_metrics_collector
from ..tracing import create_operation_span
from ..utils import mask_network_location
from ..worker.exceptions import RetryTaskAt
from .db_budget import budgeted_operation
from .operation_metadata import (
BatchRetainChildMetadata,
BatchRetainParentMetadata,
ConsolidationMetadata,
RefreshMentalModelMetadata,
RetainMetadata,
)
# Context variable for current schema (async-safe, per-task isolation)
# Note: default is None, actual default comes from config via get_current_schema()
_current_schema: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_schema", default=None)
def get_current_schema() -> str:
"""Get the current schema from context (falls back to config default)."""
schema = _current_schema.get()
if schema is None:
# Fall back to configured default schema
return get_config().database_schema
return schema
def count_tokens(text: str) -> int:
    """Return how many tokens *text* encodes to.

    Uses the module's cached tiktoken ``cl100k_base`` encoding (the GPT-4 /
    GPT-3.5 tokenizer).
    """
    encoding = _get_tiktoken_encoding()
    token_ids = encoding.encode(text)
    return len(token_ids)
def fq_table(table_name: str) -> str:
    """
    Return *table_name* prefixed with the schema active in the current context.

    Examples:
        fq_table("memory_units") -> "public.memory_units"      (schema "public")
        fq_table("memory_units") -> "tenant_xyz.memory_units"  (schema "tenant_xyz" set)
    """
    schema = get_current_schema()
    return "{}.{}".format(schema, table_name)
def _json_default(obj: Any) -> str:
    """Fallback serializer for ``json.dumps(..., default=_json_default)``.

    Renders ``datetime`` values as ISO-8601 strings. Any other type raises
    ``TypeError``, which is the contract ``json`` expects from a default hook.
    """
    if not isinstance(obj, datetime):
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
    return obj.isoformat()
# Tables that must always be referenced schema-qualified; checked at runtime
# by validate_sql_schema().
_PROTECTED_TABLES = frozenset(
    [
        "memory_units",
        "memory_links",
        "unit_entities",
        "entities",
        "entity_cooccurrences",
        "banks",
        "documents",
        "chunks",
        "async_operations",
        "file_storage",
    ]
)
# Enable runtime SQL validation (can be disabled in production for performance)
_VALIDATE_SQL_SCHEMAS = True


class UnqualifiedTableError(Exception):
    """Raised when SQL contains unqualified table references."""


def validate_sql_schema(sql: str) -> None:
    """
    Validate that SQL doesn't contain unqualified references to protected tables.

    This is a runtime safety check against cross-tenant data access: every
    table in ``_PROTECTED_TABLES`` must be written as ``schema.table``
    (build it with ``fq_table()``). Matching is case-insensitive and only
    considers ``FROM``/``JOIN``/``INTO``/``UPDATE``/``DELETE FROM`` as whole
    keywords, so identifiers that merely end in a keyword (e.g. a column
    ``copied_from``) are not mistaken for one.

    Args:
        sql: The SQL query to validate.

    Raises:
        UnqualifiedTableError: If an unqualified protected-table reference is
            found. The message names the table and quotes the offending clause.
    """
    if not _VALIDATE_SQL_SCHEMAS:
        return
    import re

    # Sorted so that, when several tables are unqualified, the reported one is
    # deterministic (frozenset iteration order depends on string hashing).
    for table in sorted(_PROTECTED_TABLES):
        name = re.escape(table)
        # Each keyword must be followed only by whitespace and then the bare
        # table name, so a schema-qualified reference ("FROM tenant.memory_units")
        # can never match: every match is an unqualified reference.
        patterns = (
            rf"\bFROM\s+{name}(?:\s|$|,|\)|;)",
            rf"\bJOIN\s+{name}(?:\s|$|,|\)|;)",
            rf"\bINTO\s+{name}(?:\s|$|\()",
            rf"\bUPDATE\s+{name}(?:\s|$)",
            rf"\bDELETE\s+FROM\s+{name}(?:\s|$|;)",
        )
        for pattern in patterns:
            # Search the original string case-insensitively instead of an
            # upper-cased copy: str.upper() can change the length of the text
            # (e.g. "ß" -> "SS"), which would misalign offsets into `sql`.
            match = re.search(pattern, sql, re.IGNORECASE)
            if match is None:
                continue
            start = match.start()
            raise UnqualifiedTableError(
                f"Unqualified table reference '{table}' in SQL. "
                f"Use fq_table('{table}') for schema safety. "
                f"SQL snippet: ...{sql[max(0, start - 10) : start + 50]}..."
            )
import asyncpg
import numpy as np
from pydantic import BaseModel, Field
from .cross_encoder import CrossEncoderModel
from .embeddings import Embeddings, create_embeddings_from_env
from .interface import MemoryEngineInterface
if TYPE_CHECKING:
from hindsight_api.extensions import OperationValidatorExtension, TenantExtension
from hindsight_api.models import RequestContext
from enum import Enum
from ..metrics import get_metrics_collector
from ..pg0 import EmbeddedPostgres, parse_pg0_url
from .entity_resolver import EntityResolver
from .llm_wrapper import LLMConfig, requires_api_key, sanitize_llm_output
from .query_analyzer import QueryAnalyzer
from .reflect import run_reflect_agent
from .reflect.tools import tool_expand, tool_recall, tool_search_mental_models, tool_search_observations
from .response_models import (
VALID_RECALL_FACT_TYPES,
EntityObservation,
EntityState,
LLMCallTrace,
MemoryFact,
ObservationRef,
ReflectResult,
TokenUsage,
ToolCallTrace,
)
from .response_models import RecallResult as RecallResultModel
from .retain import bank_utils, embedding_utils
from .retain.types import RetainContentDict
from .search import think_utils
from .search.reranking import CrossEncoderReranker, apply_combined_scoring
from .search.tags import TagGroup, TagsMatch, build_tags_where_clause
from .task_backend import BrokerTaskBackend, SyncTaskBackend, TaskBackend
class Budget(str, Enum):
    """Effort level requested for recall/reflect operations.

    Members subclass ``str``, so they compare equal to their raw values
    (``Budget.LOW == "low"``) and can be built from them (``Budget("mid")``).
    """

    LOW = "low"
    MID = "mid"
    HIGH = "high"
def utcnow():
"""Get current UTC time with timezone info."""
return datetime.now(UTC)
# Module-level logger for the memory engine, named after this module
# (``__name__``) so it can be configured through the standard logging hierarchy.
logger = logging.getLogger(__name__)
from .db_utils import acquire_with_retry
# Lazily created tiktoken encoding shared by the whole module (used by
# count_tokens() and token-budget filtering). None until first requested.
_TIKTOKEN_ENCODING = None


def _get_tiktoken_encoding():
    """Return the module-wide ``cl100k_base`` encoding, creating it on first use."""
    global _TIKTOKEN_ENCODING
    cached = _TIKTOKEN_ENCODING
    if cached is not None:
        return cached
    cached = tiktoken.get_encoding("cl100k_base")
    _TIKTOKEN_ENCODING = cached
    return cached
class MemoryEngine(MemoryEngineInterface):
"""
Advanced memory system using temporal and semantic linking with PostgreSQL.
This class provides:
- Embedding generation for semantic search
- Entity, temporal, and semantic link creation
- Think operations for formulating answers with opinions
- bank profile and disposition management
"""
def __init__(
    self,
    db_url: str | None = None,
    memory_llm_provider: str | None = None,
    memory_llm_api_key: str | None = None,
    memory_llm_model: str | None = None,
    memory_llm_base_url: str | None = None,
    # Per-operation LLM config (optional, falls back to memory_llm_* params)
    retain_llm_provider: str | None = None,
    retain_llm_api_key: str | None = None,
    retain_llm_model: str | None = None,
    retain_llm_base_url: str | None = None,
    reflect_llm_provider: str | None = None,
    reflect_llm_api_key: str | None = None,
    reflect_llm_model: str | None = None,
    reflect_llm_base_url: str | None = None,
    consolidation_llm_provider: str | None = None,
    consolidation_llm_api_key: str | None = None,
    consolidation_llm_model: str | None = None,
    consolidation_llm_base_url: str | None = None,
    embeddings: Embeddings | None = None,
    cross_encoder: CrossEncoderModel | None = None,
    query_analyzer: QueryAnalyzer | None = None,
    pool_min_size: int | None = None,
    pool_max_size: int | None = None,
    db_command_timeout: int | None = None,
    db_acquire_timeout: int | None = None,
    task_backend: TaskBackend | None = None,
    run_migrations: bool = True,
    operation_validator: "OperationValidatorExtension | None" = None,
    tenant_extension: "TenantExtension | None" = None,
    skip_llm_verification: bool | None = None,
    lazy_reranker: bool | None = None,
):
    """
    Initialize the temporal + semantic memory system.

    All parameters are optional and will be read from environment variables if not provided.
    See hindsight_api.config for environment variable names and defaults.

    Args:
        db_url: PostgreSQL connection URL. Defaults to HINDSIGHT_API_DATABASE_URL env var or "pg0".
            Also supports pg0 URLs: "pg0" or "pg0://instance-name" or "pg0://instance-name:port"
        memory_llm_provider: LLM provider. Defaults to HINDSIGHT_API_LLM_PROVIDER env var or "groq".
        memory_llm_api_key: API key for the LLM provider. Defaults to HINDSIGHT_API_LLM_API_KEY env var.
        memory_llm_model: Model name. Defaults to HINDSIGHT_API_LLM_MODEL env var.
        memory_llm_base_url: Base URL for the LLM API. Defaults based on provider
            (groq and ollama have built-in endpoints; other providers get "").
        retain_llm_provider: LLM provider for retain operations. Falls back to memory_llm_provider.
        retain_llm_api_key: API key for retain LLM. Falls back to memory_llm_api_key.
        retain_llm_model: Model for retain operations. Falls back to memory_llm_model.
        retain_llm_base_url: Base URL for retain LLM. Falls back to an explicitly configured
            memory_llm_base_url, then to the retain provider's default endpoint.
        reflect_llm_provider: LLM provider for reflect operations. Falls back to memory_llm_provider.
        reflect_llm_api_key: API key for reflect LLM. Falls back to memory_llm_api_key.
        reflect_llm_model: Model for reflect operations. Falls back to memory_llm_model.
        reflect_llm_base_url: Base URL for reflect LLM. Falls back to an explicitly configured
            memory_llm_base_url, then to the reflect provider's default endpoint.
        consolidation_llm_provider: LLM provider for consolidation operations. Falls back to memory_llm_provider.
        consolidation_llm_api_key: API key for consolidation LLM. Falls back to memory_llm_api_key.
        consolidation_llm_model: Model for consolidation operations. Falls back to memory_llm_model.
        consolidation_llm_base_url: Base URL for consolidation LLM. Falls back to an explicitly
            configured memory_llm_base_url, then to the consolidation provider's default endpoint.
        embeddings: Embeddings implementation. If not provided, created from env vars.
        cross_encoder: Cross-encoder model. If not provided, created from env vars.
        query_analyzer: Query analyzer implementation. If not provided, uses DateparserQueryAnalyzer.
        pool_min_size: Minimum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MIN_SIZE.
        pool_max_size: Maximum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MAX_SIZE.
        db_command_timeout: PostgreSQL command timeout in seconds. Defaults to HINDSIGHT_API_DB_COMMAND_TIMEOUT.
        db_acquire_timeout: Connection acquisition timeout in seconds. Defaults to HINDSIGHT_API_DB_ACQUIRE_TIMEOUT.
        task_backend: Custom task backend. If not provided, uses BrokerTaskBackend for distributed processing.
        run_migrations: Whether to run database migrations during initialize(). Default: True
        operation_validator: Optional extension to validate operations before execution.
            If provided, retain/recall/reflect operations will be validated.
        tenant_extension: Optional extension for multi-tenancy and API key authentication.
            If provided, operations require a RequestContext for authentication.
        skip_llm_verification: Skip LLM connection verification during initialization.
            Defaults to HINDSIGHT_API_SKIP_LLM_VERIFICATION env var or False.
        lazy_reranker: Delay reranker initialization until first use. Useful for retain-only
            operations that don't need the cross-encoder. Defaults to
            HINDSIGHT_API_LAZY_RERANKER env var or False.

    Raises:
        ValueError: If no LLM API key is available for a provider that requires one.
    """
    # Load config from environment for any missing parameters
    config = get_config()

    def _provider_default_base_url(provider: str) -> str:
        # groq and ollama get a built-in endpoint; any other provider gets ""
        # (presumably leaving the endpoint to LLMConfig's own default — confirm).
        normalized = provider.lower()
        if normalized == "groq":
            return "https://api.groq.com/openai/v1"
        if normalized == "ollama":
            return "http://localhost:11434/v1"
        return ""

    # Apply optimization flags from config if not explicitly provided
    self._skip_llm_verification = (
        skip_llm_verification if skip_llm_verification is not None else config.skip_llm_verification
    )
    self._lazy_reranker = lazy_reranker if lazy_reranker is not None else config.lazy_reranker

    # Apply defaults from config
    db_url = db_url or config.database_url
    memory_llm_provider = memory_llm_provider or config.llm_provider
    memory_llm_api_key = memory_llm_api_key or config.llm_api_key
    if not memory_llm_api_key and requires_api_key(memory_llm_provider):
        raise ValueError("LLM API key is required. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
    memory_llm_model = memory_llm_model or config.llm_model
    memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
    # Remember the explicitly configured base URL (None if none was given) BEFORE the
    # provider default is filled in below. Per-operation configs fall back to this
    # value rather than to the memory provider's default endpoint. Otherwise, for
    # example, groq's endpoint would be handed to an openai retain config, and the
    # per-operation provider defaults would never be applied.
    configured_memory_base_url = memory_llm_base_url

    # Track pg0 instance (if used)
    self._pg0: EmbeddedPostgres | None = None
    # Initialize PostgreSQL connection URL.
    # The actual URL will be set during initialize() after starting the server.
    # Supports: "pg0" (default instance), "pg0://instance-name" (named instance), or regular postgresql:// URL
    self._use_pg0, self._pg0_instance_name, self._pg0_port = parse_pg0_url(db_url)
    if self._use_pg0:
        self.db_url = None
    else:
        self.db_url = db_url

    # Set default base URL if not provided
    if memory_llm_base_url is None:
        memory_llm_base_url = _provider_default_base_url(memory_llm_provider)

    # Connection pool (will be created in initialize())
    self._pool = None
    self._initialized = False
    self._pool_min_size = pool_min_size if pool_min_size is not None else config.db_pool_min_size
    self._pool_max_size = pool_max_size if pool_max_size is not None else config.db_pool_max_size
    self._db_command_timeout = db_command_timeout if db_command_timeout is not None else config.db_command_timeout
    self._db_acquire_timeout = db_acquire_timeout if db_acquire_timeout is not None else config.db_acquire_timeout
    self._run_migrations = run_migrations
    self._retain_entity_lookup = config.retain_entity_lookup

    # Webhook manager (will be created in initialize() after pool is ready)
    self._webhook_manager = None
    self._http_client: httpx.AsyncClient | None = None
    # Initialize entity resolver (will be created in initialize())
    self.entity_resolver = None

    # Initialize embeddings (from env vars if not provided)
    if embeddings is not None:
        self.embeddings = embeddings
    else:
        self.embeddings = create_embeddings_from_env()

    # Initialize query analyzer
    if query_analyzer is not None:
        self.query_analyzer = query_analyzer
    else:
        from .query_analyzer import DateparserQueryAnalyzer

        self.query_analyzer = DateparserQueryAnalyzer()

    # Initialize LLM configuration (default, used as fallback)
    self._llm_config = LLMConfig(
        provider=memory_llm_provider,
        api_key=memory_llm_api_key,
        base_url=memory_llm_base_url,
        model=memory_llm_model,
    )
    # Store client and model for convenience (deprecated: use _llm_config.call() instead)
    self._llm_client = self._llm_config._client
    self._llm_model = self._llm_config.model

    # Per-operation LLM configs. Each field is resolved as:
    # explicit argument > per-operation config value > memory default.
    # The base URL falls back to the operation provider's own default endpoint.

    # Retain LLM config - for fact extraction (benefits from strong structured output)
    retain_provider = retain_llm_provider or config.retain_llm_provider or memory_llm_provider
    retain_api_key = retain_llm_api_key or config.retain_llm_api_key or memory_llm_api_key
    retain_model = retain_llm_model or config.retain_llm_model or memory_llm_model
    retain_base_url = (
        retain_llm_base_url
        or config.retain_llm_base_url
        or configured_memory_base_url
        or _provider_default_base_url(retain_provider)
    )
    self._retain_llm_config = LLMConfig(
        provider=retain_provider,
        api_key=retain_api_key,
        base_url=retain_base_url,
        model=retain_model,
    )

    # Reflect LLM config - for think/observe operations (can use lighter models)
    reflect_provider = reflect_llm_provider or config.reflect_llm_provider or memory_llm_provider
    reflect_api_key = reflect_llm_api_key or config.reflect_llm_api_key or memory_llm_api_key
    reflect_model = reflect_llm_model or config.reflect_llm_model or memory_llm_model
    reflect_base_url = (
        reflect_llm_base_url
        or config.reflect_llm_base_url
        or configured_memory_base_url
        or _provider_default_base_url(reflect_provider)
    )
    self._reflect_llm_config = LLMConfig(
        provider=reflect_provider,
        api_key=reflect_api_key,
        base_url=reflect_base_url,
        model=reflect_model,
    )

    # Consolidation LLM config - for mental model consolidation (can use efficient models)
    consolidation_provider = consolidation_llm_provider or config.consolidation_llm_provider or memory_llm_provider
    consolidation_api_key = consolidation_llm_api_key or config.consolidation_llm_api_key or memory_llm_api_key
    consolidation_model = consolidation_llm_model or config.consolidation_llm_model or memory_llm_model
    consolidation_base_url = (
        consolidation_llm_base_url
        or config.consolidation_llm_base_url
        or configured_memory_base_url
        or _provider_default_base_url(consolidation_provider)
    )
    self._consolidation_llm_config = LLMConfig(
        provider=consolidation_provider,
        api_key=consolidation_api_key,
        base_url=consolidation_base_url,
        model=consolidation_model,
    )

    # Initialize cross-encoder reranker (cached for performance)
    self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)

    # Initialize task backend.
    # If no custom backend is provided, use BrokerTaskBackend, which stores tasks in PostgreSQL.
    # The pool_getter lambda returns the pool once initialize() has created it.
    self._task_backend = task_backend or BrokerTaskBackend(
        pool_getter=lambda: self._pool,
        schema_getter=get_current_schema,
    )

    # Backpressure: limit concurrent searches to avoid overwhelming the database
    # (configurable via HINDSIGHT_API_RECALL_MAX_CONCURRENT).
    self._search_semaphore = asyncio.Semaphore(config.recall_max_concurrent)
    # Backpressure for put operations. Each put_batch holds a connection for its whole
    # transaction, so at most 5 run concurrently. This avoids pool exhaustion and
    # reduces write contention.
    self._put_semaphore = asyncio.Semaphore(5)

    # Initialize the encoding eagerly so the first token count is not delayed
    _get_tiktoken_encoding()

    # Store operation validator extension (optional)
    self._operation_validator = operation_validator

    # Store tenant extension (always set; use the default if none is provided)
    if tenant_extension is None:
        from ..extensions.builtin.tenant import DefaultTenantExtension

        tenant_extension = DefaultTenantExtension(config={})
    self._tenant_extension = tenant_extension
@property
def tenant_extension(self) -> "TenantExtension | None":
    """Tenant extension used to authenticate requests and select the schema.

    ``__init__`` installs ``DefaultTenantExtension`` when none is passed in.
    """
    extension = self._tenant_extension
    return extension
async def _validate_operation(self, validation_coro) -> "ValidationResult | None":
    """
    Await an operation-validation coroutine if an operation validator is configured.

    Args:
        validation_coro: Awaitable that resolves to a ``ValidationResult``.

    Returns:
        The ValidationResult (may contain enrichment fields), or None if no
        validator is configured.

    Raises:
        OperationValidationError: If the validator does not allow the operation.
    """
    if self._operation_validator is None:
        # The coroutine will never be awaited. Close it so Python does not emit a
        # "coroutine ... was never awaited" RuntimeWarning when it is collected.
        if asyncio.iscoroutine(validation_coro):
            validation_coro.close()
        return None
    from hindsight_api.extensions import OperationValidationError

    result = await validation_coro
    if not result.allowed:
        raise OperationValidationError(result.reason or "Operation not allowed", result.status_code)
    return result
async def _authenticate_tenant(self, request_context: "RequestContext | None") -> str:
    """
    Resolve the tenant schema for a request and bind it to the current context.

    The schema is stored in the ``_current_schema`` context variable, so it is
    isolated per async task. Use ``fq_table(table_name)`` to build
    fully-qualified table names from it.

    Args:
        request_context: Request context carrying the credentials. Must not be None.

    Returns:
        The schema name for this request. For internal (worker) and
        MCP-authenticated requests, nothing is set here and the value already in
        the context is returned; it may be None.

    Raises:
        AuthenticationError: If request_context is missing or the tenant
            extension rejects it.
    """
    from hindsight_api.extensions import AuthenticationError

    if request_context is None:
        raise AuthenticationError("RequestContext is required")

    # Skip extension authentication in two cases:
    # - Internal/background work (e.g. worker tasks) was authenticated at submission
    #   time, and execute_task restores _current_schema from the task's _schema field.
    # - MCP requests were already verified against MCP_AUTH_TOKEN by the transport
    #   layer. Re-validating against the tenant extension would fail whenever
    #   MCP_AUTH_TOKEN and TENANT_API_KEY differ.
    if request_context.internal or request_context.mcp_authenticated:
        return _current_schema.get()

    # The tenant extension is always set (it may be the default no-auth extension).
    tenant_context = await self._tenant_extension.authenticate(request_context)
    schema_name = tenant_context.schema_name
    _current_schema.set(schema_name)
    return schema_name
async def _handle_batch_retain(self, task_dict: dict[str, Any]):
    """
    Run a queued ``batch_retain`` task.

    Args:
        task_dict: Task payload. Reads 'bank_id' (required), 'contents',
            'document_tags', 'operation_id', 'strategy', the '_tenant_id' /
            '_api_key_id' attribution fields, and the optional '_file_metadata'.

    Raises:
        ValueError: If bank_id is missing.
        Exception: Anything raised by retain_batch_async propagates so that
            execute_task can retry the task.
    """
    bank_id = task_dict.get("bank_id")
    if not bank_id:
        raise ValueError("bank_id is required for batch retain task")

    contents = task_dict.get("contents", [])
    operation_id = task_dict.get("operation_id")  # used by the batch API for crash recovery
    logger.info(
        f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items, operation_id={operation_id}"
    )

    from hindsight_api.models import RequestContext

    # The worker has no API key, so it runs as internal (skipping extension auth).
    # Tenant/api-key attribution is restored from the payload so extensions (e.g.
    # operation validators) attribute the work correctly. user_initiated tells
    # them the work originated from a user request.
    worker_context = RequestContext(
        internal=True,
        user_initiated=True,
        tenant_id=task_dict.get("_tenant_id"),
        api_key_id=task_dict.get("_api_key_id"),
    )
    outbox_callback = self._build_retain_outbox_callback(
        bank_id=bank_id,
        contents=contents,
        operation_id=operation_id,
        schema=_current_schema.get(),
    )
    await self.retain_batch_async(
        bank_id=bank_id,
        contents=contents,
        document_tags=task_dict.get("document_tags"),
        request_context=worker_context,
        operation_id=operation_id,
        strategy=task_dict.get("strategy"),
        outbox_callback=outbox_callback,
    )

    # If this retain was triggered by a file conversion, attach the stored file's
    # metadata to the single document it produced.
    file_metadata = task_dict.get("_file_metadata")
    converted_doc_id = contents[0].get("document_id") if file_metadata and len(contents) == 1 else None
    if converted_doc_id:
        pool = await self._get_pool()
        async with acquire_with_retry(pool) as conn:
            await conn.execute(
                f"""
UPDATE {fq_table("documents")}
SET file_storage_key = $3,
file_original_name = $4,
file_content_type = $5,
updated_at = NOW()
WHERE id = $1 AND bank_id = $2
""",
                converted_doc_id,
                bank_id,
                file_metadata["file_storage_key"],
                file_metadata["file_original_name"],
                file_metadata["file_content_type"],
            )
    logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
async def _handle_file_convert_retain(self, task_dict: dict[str, Any]):
    """
    Handler for file conversion tasks.

    Converts a stored file to markdown. Then, in a single database transaction,
    it creates a pending ``retain`` async operation and marks this conversion
    operation as completed. The retain task itself is submitted to the task
    backend after that transaction commits, so the expensive retain pipeline
    does not hold this worker slot.

    Args:
        task_dict: Task payload.
            Required: 'bank_id', 'storage_key', 'document_id'.
            'parser': must be a non-empty, ordered list of parser names.
            Optional: 'operation_id', 'original_filename', 'content_type',
            'context', 'metadata', 'tags', 'timestamp', 'document_tags',
            'strategy', '_tenant_id', '_api_key_id'.

    Raises:
        ValueError: If bank_id, storage_key or document_id is missing.
        RuntimeError: If retrieving or converting the file fails. The message
            includes the filename and the original error is chained.
    """
    bank_id = task_dict.get("bank_id")
    storage_key = task_dict.get("storage_key")
    document_id = task_dict.get("document_id")
    operation_id = task_dict.get("operation_id")
    filename = task_dict.get("original_filename", "unknown")
    if not all([bank_id, storage_key, document_id]):
        raise ValueError("bank_id, storage_key, and document_id are required for file_convert_retain task")

    logger.info(f"[FILE_CONVERT_RETAIN] Starting for bank_id={bank_id}, document_id={document_id}, file={filename}")

    try:
        # Retrieve file from storage
        file_data = await self._file_storage.retrieve(storage_key)
        # Convert to markdown using the ordered fallback chain stored in the task payload.
        # task_dict["parser"] is always a list[str] set at submission time.
        parser_chain: list[str] = task_dict.get("parser") or []
        if not parser_chain:
            raise ValueError("No parser chain defined for file_convert_retain task")
        convert_result = await self._parser_registry.convert_with_fallback(
            parsers=parser_chain,
            file_data=file_data,
            filename=filename,
            content_type=task_dict.get("content_type"),
        )
        markdown_content = sanitize_llm_output(convert_result.content) or ""
        winning_parser = convert_result.parser_name
    except Exception as e:
        # Re-raise with filename context for better error reporting
        error_msg = f"Failed to parse file '{filename}': {str(e)}"
        logger.error(f"[FILE_CONVERT_RETAIN] {error_msg}")
        raise RuntimeError(error_msg) from e

    logger.info(
        f"[FILE_CONVERT_RETAIN] Converted file for bank_id={bank_id}, "
        f"document_id={document_id}, {len(markdown_content)} chars. Submitting retain task."
    )

    # Fire file conversion hook (e.g., for Iris billing). Best-effort: failures are logged.
    if self._operation_validator:
        try:
            from hindsight_api.extensions.operation_validator import FileConvertResult
            from hindsight_api.models import RequestContext

            convert_context = RequestContext(
                internal=True,
                user_initiated=True,
                tenant_id=task_dict.get("_tenant_id"),
                api_key_id=task_dict.get("_api_key_id"),
            )
            await self._operation_validator.on_file_convert_complete(
                FileConvertResult(
                    bank_id=bank_id,
                    parser_name=winning_parser,
                    filename=filename,
                    output_chars=len(markdown_content),
                    output_text=markdown_content,
                    request_context=convert_context,
                )
            )
        except Exception as e:
            logger.warning(f"[FILE_CONVERT_RETAIN] on_file_convert_complete hook failed: {e}")

    # Build retain task payload
    retain_contents = [
        {
            "content": markdown_content,
            "document_id": document_id,
            "context": task_dict.get("context"),
            "metadata": task_dict.get("metadata", {}),
            "tags": task_dict.get("tags", []),
            "timestamp": task_dict.get("timestamp"),
        }
    ]
    retain_task_payload: dict[str, Any] = {"contents": retain_contents}
    document_tags = task_dict.get("document_tags")
    if document_tags:
        retain_task_payload["document_tags"] = document_tags
    if task_dict.get("strategy"):
        retain_task_payload["strategy"] = task_dict["strategy"]
    # Pass tenant/api_key context through to the retain task
    if task_dict.get("_tenant_id"):
        retain_task_payload["_tenant_id"] = task_dict["_tenant_id"]
    if task_dict.get("_api_key_id"):
        retain_task_payload["_api_key_id"] = task_dict["_api_key_id"]
    # File metadata to attach after retain creates the document. Both keys are
    # read with .get(), consistent with how they are read above. Indexing them
    # here used to raise KeyError only after the conversion had already succeeded.
    retain_task_payload["_file_metadata"] = {
        "file_storage_key": storage_key,
        "file_original_name": task_dict.get("original_filename"),
        "file_content_type": task_dict.get("content_type"),
    }

    # In one transaction: create the retain async operation AND mark this conversion as completed
    retain_operation_id = uuid.uuid4()
    pool = await self._get_pool()
    async with acquire_with_retry(pool) as conn:
        async with conn.transaction():
            # Create the retain operation record
            await conn.execute(
                f"""
INSERT INTO {fq_table("async_operations")}
(operation_id, bank_id, operation_type, result_metadata, status)
VALUES ($1, $2, $3, $4, $5)
""",
                retain_operation_id,
                bank_id,
                "retain",
                json.dumps({}),
                "pending",
            )
            # Mark this file_convert_retain operation as completed.
            # str() lets operation_id arrive either as a string (JSON payload) or a UUID.
            if operation_id:
                await conn.execute(
                    f"""
UPDATE {fq_table("async_operations")}
SET status = 'completed', updated_at = NOW(), completed_at = NOW()
WHERE operation_id = $1
""",
                    uuid.UUID(str(operation_id)),
                )

    # Submit the retain task to the task backend (outside the transaction).
    # NOTE(review): if submission fails here, the transaction above has already
    # committed. The retain row stays 'pending' and this conversion is already
    # 'completed'. Confirm that the task backend or a sweeper recovers this case.
    full_retain_payload = {
        "type": "batch_retain",
        "operation_id": str(retain_operation_id),
        "bank_id": bank_id,
        **retain_task_payload,
    }
    await self._task_backend.submit_task(full_retain_payload)

    logger.info(
        f"[FILE_CONVERT_RETAIN] Completed conversion for bank_id={bank_id}, "
        f"document_id={document_id}. Retain task submitted as operation {retain_operation_id}"
    )

    # Delete file bytes from storage if configured (saves storage costs). Non-fatal on failure.
    if get_config().file_delete_after_retain:
        try:
            await self._file_storage.delete(storage_key)
            logger.info(f"[FILE_CONVERT_RETAIN] Deleted file bytes for {storage_key} (conversion completed)")
        except Exception as e:
            logger.warning(f"[FILE_CONVERT_RETAIN] Failed to delete file {storage_key}: {e}")
async def _handle_consolidation(self, task_dict: dict[str, Any]):
    """
    Execute a queued consolidation task for a single bank.

    The actual work is delegated to ``run_consolidation_job``. This handler
    validates the payload, rebuilds an internal request context carrying the
    original tenant / API-key identity, runs the job and logs a summary.

    Args:
        task_dict: Task payload. Must contain a non-empty ``bank_id``; may
            also carry ``operation_id``, ``_tenant_id`` and ``_api_key_id``.

    Returns:
        The value returned by ``run_consolidation_job`` (read here via
        ``.get('memories_processed', 0)``, so expected to be mapping-like).

    Raises:
        ValueError: If ``bank_id`` is missing or empty.
        Exception: Errors raised by the consolidation job propagate unchanged
            so ``execute_task`` can handle retries.
    """
    target_bank = task_dict.get("bank_id")
    if not target_bank:
        raise ValueError("bank_id is required for consolidation task")

    from hindsight_api.models import RequestContext

    from .consolidation import run_consolidation_job

    # The task runs outside any user request, so restore the caller's
    # tenant/API-key identity from the payload; downstream operations
    # (e.g. mental model refreshes) use it to attribute usage.
    ctx = RequestContext(
        internal=True,
        tenant_id=task_dict.get("_tenant_id"),
        api_key_id=task_dict.get("_api_key_id"),
    )
    job_kwargs = {
        "memory_engine": self,
        "bank_id": target_bank,
        "request_context": ctx,
        "operation_id": task_dict.get("operation_id"),
    }
    outcome = await run_consolidation_job(**job_kwargs)

    processed = outcome.get("memories_processed", 0)
    logger.info(f"[CONSOLIDATION] bank={target_bank} completed: {processed} processed")
    return outcome
async def _handle_refresh_mental_model(self, task_dict: dict[str, Any]):
    """
    Handler for refresh_mental_model tasks.

    Re-runs the mental model's stored ``source_query`` through reflect and
    writes the generated text, plus a serialized ``reflect_response``, back
    onto the mental model. If an operation validator is configured, its
    ``on_mental_model_refresh_complete`` hook is then called. Hook failures
    are logged and do not fail the task.

    Args:
        task_dict: Dict with 'bank_id' and 'mental_model_id'. It may also
            carry '_tenant_id' / '_api_key_id', which are used for usage
            attribution. ('operation_id' may be present but is not read here.)

    Raises:
        ValueError: If required fields are missing, or if the mental model
            does not exist in the bank.
        Exception: Any exception from reflect/update (propagates to
            execute_task for retry).
    """
    bank_id = task_dict.get("bank_id")
    mental_model_id = task_dict.get("mental_model_id")
    if not bank_id or not mental_model_id:
        raise ValueError("bank_id and mental_model_id are required for refresh_mental_model task")

    logger.info(f"[REFRESH_MENTAL_MODEL_TASK] Starting for bank_id={bank_id}, mental_model_id={mental_model_id}")

    from hindsight_api.models import RequestContext

    # Restore tenant_id/api_key_id from task payload so extensions can
    # attribute the mental_model_refresh operation to the correct org.
    internal_context = RequestContext(
        internal=True,
        tenant_id=task_dict.get("_tenant_id"),
        api_key_id=task_dict.get("_api_key_id"),
    )

    # Get the current mental model to get source_query
    mental_model = await self.get_mental_model(bank_id, mental_model_id, request_context=internal_context)
    if not mental_model:
        raise ValueError(f"Mental model {mental_model_id} not found in bank {bank_id}")

    source_query = mental_model["source_query"]

    # SECURITY: If the mental model has tags, pass them to reflect with "all_strict" matching
    # to ensure it can only access other mental models/memories with the SAME tags.
    # This prevents cross-tenant/cross-user information leakage by excluding untagged content.
    tags = mental_model.get("tags")
    tags_match = "all_strict" if tags else "any"

    # Read reflect options from trigger (if stored)
    trigger_data = mental_model.get("trigger") or {}
    fact_types = trigger_data.get("fact_types")
    exclude_mental_models = trigger_data.get("exclude_mental_models", False)
    stored_exclude_ids: list[str] = trigger_data.get("exclude_mental_model_ids") or []
    # Always exclude the model being refreshed to prevent a circular reference.
    # dict.fromkeys de-duplicates while keeping a stable, insertion-ordered
    # result; a set's order varies between processes.
    exclude_ids = list(dict.fromkeys([*stored_exclude_ids, mental_model_id]))

    reflect_result = await self.reflect_async(
        bank_id=bank_id,
        query=source_query,
        request_context=internal_context,
        tags=tags,
        tags_match=tags_match,
        fact_types=fact_types,
        exclude_mental_models=exclude_mental_models,
        exclude_mental_model_ids=exclude_ids,
    )
    generated_content = reflect_result.text or "No content generated"

    # Normalize once so serialization and the usage counts below handle an
    # empty/None based_on (or a None fact list) the same way. Previously only
    # the counting section guarded against that, and serialization crashed.
    based_on = reflect_result.based_on or {}

    # Build reflect_response payload to store.
    # based_on contains MemoryFact objects for most types, but plain dicts for directives.
    based_on_serialized: dict[str, list[dict[str, Any]]] = {}
    for fact_type, facts in based_on.items():
        serialized_facts = []
        for fact in facts or []:
            if isinstance(fact, dict):
                # Plain dict (e.g., directives with id, name, content)
                fact_id = fact["id"]
                fact_text = fact.get("text", fact.get("content", fact.get("name", "")))
            else:
                # MemoryFact object with .id and .text attributes
                fact_id = fact.id
                fact_text = fact.text
            serialized_facts.append({"id": str(fact_id), "text": fact_text, "type": fact_type})
        based_on_serialized[fact_type] = serialized_facts

    reflect_response = {
        "text": reflect_result.text,
        "based_on": based_on_serialized,
    }

    # Update the mental model with the generated content and reflect_response
    await self.update_mental_model(
        bank_id=bank_id,
        mental_model_id=mental_model_id,
        content=generated_content,
        reflect_response=reflect_response,
        request_context=internal_context,
    )

    # Call post-operation hook if validator is configured
    if self._operation_validator:
        from hindsight_api.extensions.operation_validator import MentalModelRefreshResult

        # Count facts and mental models from based_on
        facts_used = 0
        mental_models_used = 0
        for fact_type, facts in based_on.items():
            if not facts:
                continue
            if fact_type == "mental_models":
                mental_models_used += len(facts)
            else:
                facts_used += len(facts)

        # Rough token estimate: ~4 characters per token
        query_tokens = len(source_query) // 4 if source_query else 0
        output_tokens = len(generated_content) // 4 if generated_content else 0
        context_tokens = 0  # refresh doesn't use additional context

        result_ctx = MentalModelRefreshResult(
            bank_id=bank_id,
            mental_model_id=mental_model_id,
            request_context=internal_context,
            query_tokens=query_tokens,
            output_tokens=output_tokens,
            context_tokens=context_tokens,
            facts_used=facts_used,
            mental_models_used=mental_models_used,
            success=True,
        )
        try:
            await self._operation_validator.on_mental_model_refresh_complete(result_ctx)
        except Exception as hook_err:
            # Best-effort hook: the refresh itself already succeeded.
            logger.warning(f"Post-mental-model-refresh hook error (non-fatal): {hook_err}")

    logger.info(f"[REFRESH_MENTAL_MODEL_TASK] Completed for bank_id={bank_id}, mental_model_id={mental_model_id}")
async def execute_task(self, task_dict: dict[str, Any]):
"""
Execute a task by routing it to the appropriate handler.
This method is called by the task backend to execute tasks.
It receives a plain dict that can be serialized and sent over the network.
Args:
task_dict: Task dictionary with 'type' key and other payload data
Example: {'type': 'batch_retain', 'bank_id': '...', 'contents': [...]}
"""
task_type = task_dict.get("type")
operation_id = task_dict.get("operation_id")