|
16 | 16 | import time |
17 | 17 | import uuid |
18 | 18 | from collections.abc import Awaitable, Callable |
| 19 | +from dataclasses import dataclass |
19 | 20 | from datetime import UTC, datetime, timedelta, timezone |
20 | 21 | from typing import TYPE_CHECKING, Any |
21 | 22 |
|
@@ -226,6 +227,42 @@ def _get_tiktoken_encoding(): |
226 | 227 | return _TIKTOKEN_ENCODING |
227 | 228 |
|
228 | 229 |
|
| 230 | +@dataclass(frozen=True) |
| 231 | +class RefreshTagFiltering: |
| 232 | + """Resolved tag filtering parameters for mental model refresh.""" |
| 233 | + |
| 234 | + tags: list[str] | None |
| 235 | + tags_match: TagsMatch |
| 236 | + tag_groups: list[TagGroup] | None |
| 237 | + |
| 238 | + |
| 239 | +def _resolve_refresh_tag_filtering( |
| 240 | + model_tags: list[str] | None, |
| 241 | + trigger_data: dict[str, Any], |
| 242 | +) -> RefreshTagFiltering: |
| 243 | + """Resolve tag filtering parameters for mental model refresh. |
| 244 | +
|
| 245 | + Takes raw trigger dict from DB (JSONB with no fixed schema guarantee) |
| 246 | + and resolves the tag filtering to use during reflect. |
| 247 | +
|
| 248 | + Priority: |
| 249 | + - If trigger has tag_groups, use those (overrides flat tags entirely) |
| 250 | + - If trigger has tags_match, use model's tags with that match mode |
| 251 | + - Otherwise default to all_strict when tags present (security isolation) |
| 252 | + """ |
| 253 | + trigger_tag_groups = trigger_data.get("tag_groups") |
| 254 | + if trigger_tag_groups is not None: |
| 255 | + from pydantic import TypeAdapter |
| 256 | + |
| 257 | + adapter = TypeAdapter(TagGroup) |
| 258 | + parsed = [adapter.validate_python(tg) for tg in trigger_tag_groups] |
| 259 | + return RefreshTagFiltering(tags=None, tags_match="any", tag_groups=parsed) |
| 260 | + |
| 261 | + trigger_tags_match = trigger_data.get("tags_match") |
| 262 | + tags_match: TagsMatch = trigger_tags_match if trigger_tags_match else ("all_strict" if model_tags else "any") |
| 263 | + return RefreshTagFiltering(tags=model_tags, tags_match=tags_match, tag_groups=None) |
| 264 | + |
| 265 | + |
229 | 266 | class MemoryEngine(MemoryEngineInterface): |
230 | 267 | """ |
231 | 268 | Advanced memory system using temporal and semantic linking with PostgreSQL. |
@@ -908,26 +945,23 @@ async def _handle_refresh_mental_model(self, task_dict: dict[str, Any]): |
908 | 945 |
|
909 | 946 | source_query = mental_model["source_query"] |
910 | 947 |
|
911 | | - # SECURITY: If the mental model has tags, pass them to reflect with "all_strict" matching |
912 | | - # to ensure it can only access other mental models/memories with the SAME tags. |
913 | | - # This prevents cross-tenant/cross-user information leakage by excluding untagged content. |
914 | | - tags = mental_model.get("tags") |
915 | | - tags_match = "all_strict" if tags else "any" |
916 | | - |
917 | 948 | # Read reflect options from trigger (if stored) |
918 | 949 | trigger_data = mental_model.get("trigger") or {} |
919 | 950 | fact_types = trigger_data.get("fact_types") |
920 | 951 | exclude_mental_models = trigger_data.get("exclude_mental_models", False) |
921 | 952 | stored_exclude_ids: list[str] = trigger_data.get("exclude_mental_model_ids") or [] |
922 | 953 |
|
| 954 | + tag_filtering = _resolve_refresh_tag_filtering(mental_model.get("tags"), trigger_data) |
| 955 | + |
923 | 956 | # Run reflect to generate new content, excluding the mental model being refreshed |
924 | 957 | # Always add self to excluded IDs to prevent circular reference |
925 | 958 | reflect_result = await self.reflect_async( |
926 | 959 | bank_id=bank_id, |
927 | 960 | query=source_query, |
928 | 961 | request_context=internal_context, |
929 | | - tags=tags, |
930 | | - tags_match=tags_match, |
| 962 | + tags=tag_filtering.tags, |
| 963 | + tags_match=tag_filtering.tags_match, |
| 964 | + tag_groups=tag_filtering.tag_groups, |
931 | 965 | fact_types=fact_types, |
932 | 966 | exclude_mental_models=exclude_mental_models, |
933 | 967 | exclude_mental_model_ids=list({*stored_exclude_ids, mental_model_id}), |
@@ -6581,26 +6615,23 @@ async def refresh_mental_model( |
6581 | 6615 |
|
6582 | 6616 | # Create parent span for mental model refresh operation |
6583 | 6617 | with create_operation_span("mental_model_refresh", bank_id): |
6584 | | - # SECURITY: If the mental model has tags, pass them to reflect with "all_strict" matching |
6585 | | - # to ensure it can only access other mental models/memories with the SAME tags. |
6586 | | - # This prevents cross-tenant/cross-user information leakage by excluding untagged content. |
6587 | | - tags = mental_model.get("tags") |
6588 | | - tags_match = "all_strict" if tags else "any" |
6589 | | - |
6590 | 6618 | # Read reflect options from trigger (if stored) |
6591 | 6619 | trigger_data = mental_model.get("trigger") or {} |
6592 | 6620 | fact_types = trigger_data.get("fact_types") |
6593 | 6621 | exclude_mental_models = trigger_data.get("exclude_mental_models", False) |
6594 | 6622 | stored_exclude_ids: list[str] = trigger_data.get("exclude_mental_model_ids") or [] |
6595 | 6623 |
|
| 6624 | + tag_filtering = _resolve_refresh_tag_filtering(mental_model.get("tags"), trigger_data) |
| 6625 | + |
6596 | 6626 | # Run reflect with the source query, excluding the mental model being refreshed |
6597 | 6627 | # Skip creating a nested "hindsight.reflect" span since we already have "hindsight.mental_model_refresh" |
6598 | 6628 | reflect_result = await self.reflect_async( |
6599 | 6629 | bank_id=bank_id, |
6600 | 6630 | query=mental_model["source_query"], |
6601 | 6631 | request_context=request_context, |
6602 | | - tags=tags, |
6603 | | - tags_match=tags_match, |
| 6632 | + tags=tag_filtering.tags, |
| 6633 | + tags_match=tag_filtering.tags_match, |
| 6634 | + tag_groups=tag_filtering.tag_groups, |
6604 | 6635 | fact_types=fact_types, |
6605 | 6636 | exclude_mental_models=exclude_mental_models, |
6606 | 6637 | exclude_mental_model_ids=list({*stored_exclude_ids, mental_model_id}), |
|
0 commit comments