From 91933699c998441b8ff1b0c067aac9322c9f9192 Mon Sep 17 00:00:00 2001 From: MarioAlessandroNapoli Date: Wed, 18 Feb 2026 19:26:36 +0100 Subject: [PATCH 1/4] feat: import checkpoint history via supersteps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the current create+update_state import (which loses checkpoint history) with supersteps-based import that preserves time-travel. Two conversion strategies: - Direct: uses metadata.writes when available (local checkpointer) - Delta: computes state diffs between consecutive checkpoints (Cloud API) Delta approach skips middleware no-ops, only generating supersteps for state-changing steps (e.g. 100 checkpoints → 30 supersteps). Includes automatic fallback to legacy import if the target API doesn't support supersteps (400/422 → switches for all remaining threads). Validation now optionally compares checkpoint history counts for a sample thread (enabled with --test-single in full migration). --- langgraph_export/client.py | 35 +++++ langgraph_export/migrator.py | 240 ++++++++++++++++++++++++++++++++--- migrate_threads.py | 40 ++++-- requirements.txt | 2 +- 4 files changed, 289 insertions(+), 28 deletions(-) diff --git a/langgraph_export/client.py b/langgraph_export/client.py index 8e3721a..c976796 100644 --- a/langgraph_export/client.py +++ b/langgraph_export/client.py @@ -105,6 +105,41 @@ async def create_thread( ) return dict(thread) if thread else {} + async def create_thread_with_history( + self, + thread_id: str, + metadata: Optional[Dict[str, Any]] = None, + supersteps: Optional[List[Dict[str, Any]]] = None, + if_exists: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Create a new thread with pre-populated checkpoint history. + + Uses the supersteps parameter to replay state updates, + reconstructing the full checkpoint chain on the target. 
+ + Args: + thread_id: Thread ID to use + metadata: Optional metadata to attach + supersteps: List of supersteps, each containing updates + with values and as_node + if_exists: Conflict behavior ('raise' or 'do_nothing') + + Returns: + Created thread dictionary + """ + kwargs: Dict[str, Any] = { + "thread_id": thread_id, + "metadata": metadata or {}, + } + if supersteps: + kwargs["supersteps"] = supersteps + if if_exists: + kwargs["if_exists"] = if_exists + + thread = await self._client.threads.create(**kwargs) + return dict(thread) if thread else {} + async def update_thread_state( self, thread_id: str, diff --git a/langgraph_export/migrator.py b/langgraph_export/migrator.py index 85f231a..509dc0a 100644 --- a/langgraph_export/migrator.py +++ b/langgraph_export/migrator.py @@ -229,6 +229,168 @@ async def export_threads( ) return stats + @staticmethod + def _compute_values_delta( + prev_values: Dict[str, Any], + curr_values: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Compute the delta between two checkpoint states. + + For messages (add_messages reducer): extracts only new messages by ID. + For lists of dicts with 'id' field: extracts only new items by ID. + For all other fields: includes only if changed. 
+ """ + delta: Dict[str, Any] = {} + + for key, curr_val in curr_values.items(): + prev_val = prev_values.get(key) + + if curr_val == prev_val: + continue + + if key == "messages": + # add_messages reducer: diff by message ID + prev_ids = { + m.get("id") for m in (prev_val or []) if isinstance(m, dict) + } + new_msgs = [ + m for m in (curr_val or []) + if isinstance(m, dict) and m.get("id") not in prev_ids + ] + if new_msgs: + delta["messages"] = new_msgs + + elif ( + isinstance(curr_val, list) + and curr_val + and isinstance(curr_val[0], dict) + and "id" in curr_val[0] + ): + # List of dicts with ID → dedup-aware delta + prev_ids = { + item.get("id") + for item in (prev_val or []) + if isinstance(item, dict) + } + new_items = [ + item for item in curr_val + if isinstance(item, dict) and item.get("id") not in prev_ids + ] + if new_items: + delta[key] = new_items + + else: + # Scalar or non-ID list → include if changed + delta[key] = curr_val + + return delta + + @staticmethod + def _history_to_supersteps( + history: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """ + Convert checkpoint history to supersteps format for threads.create(). + + History from get_history() is most-recent-first. + Supersteps must be chronological (oldest-first). + + Strategy: + - If metadata.writes is available (local checkpointer), use it directly. + - Otherwise (Cloud API), compute deltas between consecutive states. + Only state-changing steps produce supersteps; middleware no-ops are skipped. 
+ """ + if not history: + return [] + + chronological = list(reversed(history)) + supersteps: List[Dict[str, Any]] = [] + + # Check if writes data is available (first non-empty checkpoint) + has_writes = any( + isinstance((h.get("metadata") or {}).get("writes"), dict) + for h in history + ) + + if has_writes: + # Direct approach: use metadata.writes + for state in chronological: + metadata = state.get("metadata") or {} + writes = metadata.get("writes") + if not writes or not isinstance(writes, dict): + continue + updates = [] + for node_name, node_values in writes.items(): + if node_values is None: + continue + updates.append({"values": node_values, "as_node": node_name}) + if updates: + supersteps.append({"updates": updates}) + else: + # Delta approach: compute diffs between consecutive states + prev_values: Dict[str, Any] = {} + for i, state in enumerate(chronological): + curr_values = state.get("values") or {} + + if i == 0: + # Include initial state as first superstep + if curr_values: + supersteps.append({ + "updates": [{ + "values": curr_values, + "as_node": "__start__", + }], + }) + prev_values = curr_values + continue + + # as_node = what the previous checkpoint's next field says ran + prev_next = chronological[i - 1].get("next", []) + as_node = prev_next[0] if prev_next else "__unknown__" + + delta = ThreadMigrator._compute_values_delta(prev_values, curr_values) + + if delta: + supersteps.append({ + "updates": [{"values": delta, "as_node": as_node}], + }) + + prev_values = curr_values + + return supersteps + + async def _import_thread_with_history( + self, + thread: ThreadData, + ) -> int: + """ + Import a single thread using supersteps to preserve checkpoint history. + + Returns the number of checkpoints imported. + Raises on failure so caller can handle fallback. 
+ """ + supersteps = self._history_to_supersteps(thread.history) + + await self._target_client.create_thread_with_history( + thread_id=thread.thread_id, + metadata=thread.metadata, + supersteps=supersteps if supersteps else None, + ) + return len(supersteps) + + async def _import_thread_legacy(self, thread: ThreadData) -> None: + """Fallback: create thread + update_state (no history preservation).""" + await self._target_client.create_thread( + thread_id=thread.thread_id, + metadata=thread.metadata, + ) + if thread.values: + await self._target_client.update_thread_state( + thread_id=thread.thread_id, + values=thread.values, + ) + async def import_threads( self, threads: List[ThreadData], @@ -238,43 +400,61 @@ async def import_threads( """ Import threads to target deployment. + Tries supersteps-based import first to preserve checkpoint history. + Falls back to legacy create+update_state if the target API + doesn't support supersteps. + Args: threads: Threads to import dry_run: If True, don't actually create threads progress_callback: Optional callback for progress updates Returns: - Dictionary with created, skipped, failed counts + Dictionary with created, skipped, failed, checkpoints counts """ if not self._target_client: raise RuntimeError("Target URL not configured") - results = {"created": 0, "skipped": 0, "failed": 0} + results = {"created": 0, "skipped": 0, "failed": 0, "checkpoints": 0} + use_legacy = False for i, thread in enumerate(threads): try: if dry_run: + supersteps = self._history_to_supersteps(thread.history) results["skipped"] += 1 if progress_callback: - progress_callback(i + 1, f"[DRY-RUN] Would create {thread.thread_id[:8]}...") + progress_callback( + i + 1, + f"[DRY-RUN] Would create {thread.thread_id[:8]}... 
" + f"({len(supersteps)} supersteps)" + ) continue - # Create thread - await self._target_client.create_thread( - thread_id=thread.thread_id, - metadata=thread.metadata, - ) - - # Update state if available - if thread.values: - await self._target_client.update_thread_state( - thread_id=thread.thread_id, - values=thread.values, - ) + if not use_legacy: + try: + checkpoints = await self._import_thread_with_history(thread) + results["checkpoints"] += checkpoints + except APIStatusError as e: + # supersteps not supported — switch to legacy for all remaining + if e.status_code in (400, 422): + use_legacy = True + if progress_callback: + progress_callback( + i + 1, + "supersteps not supported, falling back to legacy import" + ) + await self._import_thread_legacy(thread) + else: + raise + else: + await self._import_thread_legacy(thread) results["created"] += 1 if progress_callback: - progress_callback(i + 1, f"Created {thread.thread_id[:8]}...") + cp_count = len(thread.history) + mode = "legacy" if use_legacy else f"{cp_count} checkpoints" + progress_callback(i + 1, f"Created {thread.thread_id[:8]}... ({mode})") except ConflictError: results["skipped"] += 1 @@ -295,12 +475,19 @@ async def import_threads( return results - async def validate_migration(self) -> Dict[str, int]: + async def validate_migration( + self, + check_history: bool = False, + sample_thread_id: Optional[str] = None, + ) -> Dict[str, Any]: """ Compare thread counts between source and target. + Optionally validates checkpoint history for a sample thread. + Returns: - Dictionary with source_count, target_count + Dictionary with source_count, target_count, and optionally + history_source/history_target for the sample thread. 
""" source_count = 0 target_count = 0 @@ -313,8 +500,23 @@ async def validate_migration(self) -> Dict[str, int]: threads = await self._target_client.search_threads(limit=10000) target_count = len(threads) - return { + result: Dict[str, Any] = { "source_count": source_count, "target_count": target_count, "difference": source_count - target_count, } + + # Validate checkpoint history for a sample thread + if check_history and sample_thread_id: + if self._source_client: + src_history = await self._source_client.get_thread_history( + sample_thread_id, limit=1000 + ) + result["history_source"] = len(src_history) + if self._target_client: + tgt_history = await self._target_client.get_thread_history( + sample_thread_id, limit=1000 + ) + result["history_target"] = len(tgt_history) + + return result diff --git a/migrate_threads.py b/migrate_threads.py index 45fd63e..cc6bd0e 100644 --- a/migrate_threads.py +++ b/migrate_threads.py @@ -160,7 +160,8 @@ def update_progress(count: int, message: str) -> None: table.add_row("Created", str(results["created"])) table.add_row("Skipped (exists)", str(results["skipped"])) table.add_row("Failed", str(results["failed"])) - table.add_row("Total", str(len(threads))) + table.add_row("Checkpoints", str(results.get("checkpoints", 0))) + table.add_row("Total threads", str(len(threads))) console.print(table) @@ -245,6 +246,7 @@ def update_progress2(count: int, message: str) -> None: table.add_row("Created", str(results["created"])) table.add_row("Skipped (exists)", str(results["skipped"])) table.add_row("Failed", str(results["failed"])) + table.add_row("Checkpoints", str(results.get("checkpoints", 0))) console.print(table) @@ -254,22 +256,44 @@ def update_progress2(count: int, message: str) -> None: border_style="blue", )) - validation = await migrator.validate_migration() + # For test-single, also validate checkpoint history + sample_id = threads[0].thread_id if test_single and threads else None + validation = await migrator.validate_migration( 
+ check_history=bool(sample_id), + sample_thread_id=sample_id, + ) table = Table(title="Validation") - table.add_column("Deployment", style="cyan") - table.add_column("Threads", justify="right", style="magenta") - - table.add_row("Source", str(validation["source_count"])) - table.add_row("Target", str(validation["target_count"])) + table.add_column("Metric", style="cyan") + table.add_column("Source", justify="right", style="magenta") + table.add_column("Target", justify="right", style="magenta") + + table.add_row( + "Threads", + str(validation["source_count"]), + str(validation["target_count"]), + ) + if "history_source" in validation: + table.add_row( + f"Checkpoints ({sample_id[:8]}...)", + str(validation["history_source"]), + str(validation["history_target"]), + ) console.print(table) if validation["difference"] <= 0: - console.print("[green]✓ Migration validated successfully![/green]") + console.print("[green]✓ Thread count validated![/green]") else: console.print(f"[yellow]⚠ Target has {validation['difference']} fewer threads[/yellow]") + if "history_source" in validation: + if validation["history_source"] == validation["history_target"]: + console.print("[green]✓ Checkpoint history validated![/green]") + else: + diff = validation["history_source"] - validation["history_target"] + console.print(f"[yellow]⚠ History mismatch: {diff} checkpoints missing[/yellow]") + async def run_validate( source_url: str, diff --git a/requirements.txt b/requirements.txt index 81317e6..f51021f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -langgraph-sdk>=0.1.66 +langgraph-sdk>=0.3.4 rich>=13.9.0 python-dotenv>=1.0.0 sqlalchemy[asyncio]>=2.0.0 # Optional: for PostgreSQL export From 23fe64f9227f05f9425dafabae29192693a03085 Mon Sep 17 00:00:00 2001 From: MarioAlessandroNapoli Date: Wed, 18 Feb 2026 19:39:52 +0100 Subject: [PATCH 2/4] feat: support separate source/target API keys for cross-org migration (AE-260) Add --source-api-key and --target-api-key CLI 
flags with LANGSMITH_SOURCE_API_KEY and LANGSMITH_TARGET_API_KEY env vars. Backward compatible: LANGSMITH_API_KEY used as fallback for both. --- .env.example | 7 ++- README.md | 19 +++++++- langgraph_export/migrator.py | 15 ++++-- migrate_threads.py | 95 +++++++++++++++++++++++++++--------- 4 files changed, 106 insertions(+), 30 deletions(-) diff --git a/.env.example b/.env.example index a187321..ac56b4b 100644 --- a/.env.example +++ b/.env.example @@ -1,8 +1,13 @@ -# LangSmith API Key (required) +# LangSmith API Key — shared fallback for source and target # Get your key from: https://smith.langchain.com/ → Settings → API Keys # Use a Service Key (lsv2_sk_...) for deployment access LANGSMITH_API_KEY=lsv2_sk_your_api_key_here +# For cross-org migration: separate keys for source and target +# These override LANGSMITH_API_KEY when set +# LANGSMITH_SOURCE_API_KEY=lsv2_sk_source_org_key +# LANGSMITH_TARGET_API_KEY=lsv2_sk_target_org_key + # PostgreSQL Database URL (optional, for --export-postgres) # Example: postgresql://user:password@localhost:5432/dbname DATABASE_URL= diff --git a/README.md b/README.md index c3662fb..29da986 100644 --- a/README.md +++ b/README.md @@ -59,9 +59,13 @@ cp .env.example .env Edit `.env` with your credentials: ```bash -# Required: LangSmith API key +# Required: LangSmith API key (shared fallback for source and target) LANGSMITH_API_KEY=lsv2_sk_your_api_key_here +# For cross-org migration: separate keys override the shared key +# LANGSMITH_SOURCE_API_KEY=lsv2_sk_source_org_key +# LANGSMITH_TARGET_API_KEY=lsv2_sk_target_org_key + # Optional: PostgreSQL connection URL (for --export-postgres) DATABASE_URL=postgresql://user:password@localhost:5432/dbname ``` @@ -99,10 +103,19 @@ This creates two tables: Transfer all threads from one deployment to another: ```bash +# Same org (shared API key) python migrate_threads.py \ --source-url https://my-prod.langgraph.app \ --target-url https://my-dev.langgraph.app \ --full + +# Cross-org (separate API 
keys) +python migrate_threads.py \ + --source-url https://org1.langgraph.app \ + --target-url https://org2.langgraph.app \ + --source-api-key lsv2_sk_source... \ + --target-api-key lsv2_sk_target... \ + --full ``` ### Import from JSON @@ -132,7 +145,9 @@ python migrate_threads.py \ |----------|-------------| | `--source-url` | Source LangGraph Cloud deployment URL | | `--target-url` | Target LangGraph Cloud deployment URL | -| `--api-key` | LangSmith API key (or set in `.env`) | +| `--api-key` | Shared API key fallback (or `LANGSMITH_API_KEY`) | +| `--source-api-key` | Source API key for cross-org (or `LANGSMITH_SOURCE_API_KEY`) | +| `--target-api-key` | Target API key for cross-org (or `LANGSMITH_TARGET_API_KEY`) | | `--database-url` | PostgreSQL URL (or set in `.env`) | | `--export-json FILE` | Export threads to JSON file | | `--export-postgres` | Export threads to PostgreSQL database | diff --git a/langgraph_export/migrator.py b/langgraph_export/migrator.py index 509dc0a..cab04cd 100644 --- a/langgraph_export/migrator.py +++ b/langgraph_export/migrator.py @@ -28,6 +28,8 @@ def __init__( source_url: Optional[str] = None, target_url: Optional[str] = None, api_key: str = "", + source_api_key: Optional[str] = None, + target_api_key: Optional[str] = None, rate_limit_delay: float = 0.2, ): """ @@ -36,12 +38,15 @@ def __init__( Args: source_url: Source LangGraph deployment URL target_url: Target LangGraph deployment URL (for migration) - api_key: LangSmith API key + api_key: Shared API key (fallback for source/target) + source_api_key: API key for source deployment (overrides api_key) + target_api_key: API key for target deployment (overrides api_key) rate_limit_delay: Delay between API calls (seconds) """ self.source_url = source_url self.target_url = target_url - self.api_key = api_key + self.source_api_key = source_api_key or api_key + self.target_api_key = target_api_key or api_key self.rate_limit_delay = rate_limit_delay self._source_client: 
Optional[LangGraphClient] = None @@ -58,11 +63,11 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.close() async def connect(self) -> None: - """Initialize clients.""" + """Initialize clients with their respective API keys.""" if self.source_url: - self._source_client = LangGraphClient(self.source_url, self.api_key) + self._source_client = LangGraphClient(self.source_url, self.source_api_key) if self.target_url: - self._target_client = LangGraphClient(self.target_url, self.api_key) + self._target_client = LangGraphClient(self.target_url, self.target_api_key) async def close(self) -> None: """Close all clients and exporters.""" diff --git a/migrate_threads.py b/migrate_threads.py index cc6bd0e..6480307 100644 --- a/migrate_threads.py +++ b/migrate_threads.py @@ -33,7 +33,7 @@ async def run_export_json( source_url: str, - api_key: str, + source_api_key: str, output_file: str, test_single: bool = False, ) -> None: @@ -43,7 +43,7 @@ async def run_export_json( border_style="cyan", )) - async with ThreadMigrator(source_url=source_url, api_key=api_key) as migrator: + async with ThreadMigrator(source_url=source_url, source_api_key=source_api_key) as migrator: json_exporter = migrator.add_json_exporter(output_file) with Progress( @@ -70,7 +70,7 @@ def update_progress(count: int, message: str) -> None: async def run_export_postgres( source_url: str, - api_key: str, + source_api_key: str, database_url: str, output_file: str, test_single: bool = False, @@ -81,7 +81,7 @@ async def run_export_postgres( border_style="cyan", )) - async with ThreadMigrator(source_url=source_url, api_key=api_key) as migrator: + async with ThreadMigrator(source_url=source_url, source_api_key=source_api_key) as migrator: json_exporter = migrator.add_json_exporter(output_file) pg_exporter = migrator.add_postgres_exporter(database_url) @@ -112,7 +112,7 @@ def update_progress(count: int, message: str) -> None: async def run_import_json( target_url: str, - api_key: str, + 
target_api_key: str, input_file: str, dry_run: bool = False, ) -> None: @@ -133,7 +133,7 @@ async def run_import_json( if dry_run: console.print("\n[yellow]⚠ DRY-RUN MODE: No changes will be made[/yellow]") - async with ThreadMigrator(target_url=target_url, api_key=api_key) as migrator: + async with ThreadMigrator(target_url=target_url, target_api_key=target_api_key) as migrator: with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -169,7 +169,8 @@ def update_progress(count: int, message: str) -> None: async def run_full_migration( source_url: str, target_url: str, - api_key: str, + source_api_key: str, + target_api_key: str, backup_file: str, dry_run: bool = False, test_single: bool = False, @@ -184,7 +185,8 @@ async def run_full_migration( async with ThreadMigrator( source_url=source_url, target_url=target_url, - api_key=api_key, + source_api_key=source_api_key, + target_api_key=target_api_key, ) as migrator: json_exporter = migrator.add_json_exporter(backup_file) @@ -298,7 +300,8 @@ def update_progress2(count: int, message: str) -> None: async def run_validate( source_url: str, target_url: str, - api_key: str, + source_api_key: str, + target_api_key: str, ) -> None: """Validate migration by comparing thread counts.""" console.print(Panel.fit( @@ -309,7 +312,8 @@ async def run_validate( async with ThreadMigrator( source_url=source_url, target_url=target_url, - api_key=api_key, + source_api_key=source_api_key, + target_api_key=target_api_key, ) as migrator: validation = await migrator.validate_migration() @@ -343,9 +347,13 @@ def main() -> None: # Export to PostgreSQL python migrate_threads.py --source-url https://my.langgraph.app --export-postgres - # Full migration + # Full migration (same org) python migrate_threads.py --source-url https://prod.langgraph.app --target-url https://dev.langgraph.app --full + # Full migration (cross-org, separate API keys) + python migrate_threads.py --source-url https://org1.langgraph.app 
--target-url https://org2.langgraph.app \\ + --source-api-key lsv2_sk_... --target-api-key lsv2_sk_... --full + # Import from JSON python migrate_threads.py --target-url https://dev.langgraph.app --import-json backup.json """ @@ -353,7 +361,9 @@ def main() -> None: parser.add_argument("--source-url", help="Source LangGraph deployment URL") parser.add_argument("--target-url", help="Target LangGraph deployment URL") - parser.add_argument("--api-key", help="LangSmith API key (or set LANGSMITH_API_KEY)") + parser.add_argument("--api-key", help="Shared API key fallback (or set LANGSMITH_API_KEY)") + parser.add_argument("--source-api-key", help="Source API key (or set LANGSMITH_SOURCE_API_KEY)") + parser.add_argument("--target-api-key", help="Target API key (or set LANGSMITH_TARGET_API_KEY)") parser.add_argument("--database-url", help="PostgreSQL URL (or set DATABASE_URL)") parser.add_argument("--export-json", metavar="FILE", help="Export to JSON file") @@ -368,12 +378,29 @@ def main() -> None: args = parser.parse_args() - # Load from environment if not provided - api_key = args.api_key or os.getenv("LANGSMITH_API_KEY") + # Resolve API keys: CLI flag > specific env var > shared CLI flag > shared env var + shared_key = args.api_key or os.getenv("LANGSMITH_API_KEY") + source_api_key = ( + args.source_api_key + or os.getenv("LANGSMITH_SOURCE_API_KEY") + or shared_key + ) + target_api_key = ( + args.target_api_key + or os.getenv("LANGSMITH_TARGET_API_KEY") + or shared_key + ) database_url = args.database_url or os.getenv("DATABASE_URL") - if not api_key: - console.print("[red]✗ LANGSMITH_API_KEY required[/red]") + # Validate keys based on the command + needs_source = bool(args.export_json or args.export_postgres or args.full or args.validate) + needs_target = bool(args.import_json or args.full or args.validate) + + if needs_source and not source_api_key: + console.print("[red]✗ Source API key required (--source-api-key or LANGSMITH_SOURCE_API_KEY or 
LANGSMITH_API_KEY)[/red]") + sys.exit(1) + if needs_target and not target_api_key: + console.print("[red]✗ Target API key required (--target-api-key or LANGSMITH_TARGET_API_KEY or LANGSMITH_API_KEY)[/red]") sys.exit(1) # Display header @@ -382,6 +409,10 @@ def main() -> None: border_style="magenta", )) + # Show key info (cross-org vs same-org) + if needs_source and needs_target and source_api_key != target_api_key: + console.print("[cyan]ℹ Cross-org migration: using separate API keys for source and target[/cyan]\n") + if args.test_single: console.print("[yellow]⚠ TEST MODE: Single thread only[/yellow]\n") @@ -392,7 +423,10 @@ def main() -> None: console.print("[red]✗ --source-url required[/red]") sys.exit(1) asyncio.run(run_export_json( - args.source_url, api_key, args.export_json, args.test_single + source_url=args.source_url, + source_api_key=source_api_key, + output_file=args.export_json, + test_single=args.test_single, )) elif args.export_postgres: @@ -403,7 +437,11 @@ def main() -> None: console.print("[red]✗ --database-url or DATABASE_URL required[/red]") sys.exit(1) asyncio.run(run_export_postgres( - args.source_url, api_key, database_url, args.backup_file, args.test_single + source_url=args.source_url, + source_api_key=source_api_key, + database_url=database_url, + output_file=args.backup_file, + test_single=args.test_single, )) elif args.import_json: @@ -411,7 +449,10 @@ def main() -> None: console.print("[red]✗ --target-url required[/red]") sys.exit(1) asyncio.run(run_import_json( - args.target_url, api_key, args.import_json, args.dry_run + target_url=args.target_url, + target_api_key=target_api_key, + input_file=args.import_json, + dry_run=args.dry_run, )) elif args.full: @@ -419,15 +460,25 @@ def main() -> None: console.print("[red]✗ Both --source-url and --target-url required[/red]") sys.exit(1) asyncio.run(run_full_migration( - args.source_url, args.target_url, api_key, - args.backup_file, args.dry_run, args.test_single + source_url=args.source_url, + 
target_url=args.target_url, + source_api_key=source_api_key, + target_api_key=target_api_key, + backup_file=args.backup_file, + dry_run=args.dry_run, + test_single=args.test_single, )) elif args.validate: if not args.source_url or not args.target_url: console.print("[red]✗ Both --source-url and --target-url required[/red]") sys.exit(1) - asyncio.run(run_validate(args.source_url, args.target_url, api_key)) + asyncio.run(run_validate( + source_url=args.source_url, + target_url=args.target_url, + source_api_key=source_api_key, + target_api_key=target_api_key, + )) else: parser.print_help() From b928184dea597b3374e3fd237b96fc374b0906f2 Mon Sep 17 00:00:00 2001 From: MarioAlessandroNapoli Date: Wed, 18 Feb 2026 20:05:24 +0100 Subject: [PATCH 3/4] feat: metadata filter, retry with backoff, history pagination, streaming JSON (AE-261, AE-262) - Add --metadata-filter for server-side JSONB containment filtering - Add --history-limit for optional checkpoint cap per thread - Retry API calls with exponential backoff + jitter (3 attempts) - Paginate checkpoint history via `before` cursor (no more limit=100) - Rewrite JSONExporter for streaming writes (no in-memory buffering) - Update README with new flags, examples, and feature docs --- README.md | 29 +++++- langgraph_export/client.py | 70 ++++++++++++- langgraph_export/exporters/json_exporter.py | 110 +++++++++----------- langgraph_export/migrator.py | 75 +++++++++---- migrate_threads.py | 67 ++++++++++-- 5 files changed, 264 insertions(+), 87 deletions(-) diff --git a/README.md b/README.md index 29da986..47c7adc 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,10 @@ This tool solves all of these problems. 
- **Test mode** - Export/migrate a single thread first to verify everything works - **Dry-run mode** - Preview changes without making any modifications - **Progress tracking** - Real-time progress bars and detailed summaries +- **Metadata filtering** - Export only threads matching specific metadata (e.g., by workspace) +- **Retry with backoff** - Automatic retries on API failures (3 attempts with exponential backoff) +- **Streaming JSON** - Memory-efficient export that writes threads incrementally to disk +- **Full history pagination** - Fetches complete checkpoint history regardless of thread size ## Use Cases @@ -31,6 +35,7 @@ This tool solves all of these problems. | Store in your own PostgreSQL | `--export-postgres` | | Environment migration (staging → prod) | `--migrate` | | Disaster recovery | `--import-json backup.json` | +| Export single workspace (multi-tenant) | `--metadata-filter '{"workspace_id": 4}'` | ## Installation @@ -128,6 +133,24 @@ python migrate_threads.py \ --import-json threads_backup.json ``` +### Filter by metadata + +Export only threads matching specific metadata (useful for multi-tenant deployments): + +```bash +# Export threads for a specific workspace +python migrate_threads.py \ + --source-url https://my-deployment.langgraph.app \ + --export-json workspace_4.json \ + --metadata-filter '{"workspace_id": 4}' + +# Export threads owned by a specific user +python migrate_threads.py \ + --source-url https://my-deployment.langgraph.app \ + --export-json user_backup.json \ + --metadata-filter '{"owner": "user@example.com"}' +``` + ### Test with a single thread first Always recommended before a full operation: @@ -157,6 +180,8 @@ python migrate_threads.py \ | `--validate` | Compare source vs target thread counts | | `--dry-run` | Simulation mode (no changes made) | | `--test-single` | Process only one thread (for testing) | +| `--metadata-filter JSON` | Filter threads by metadata (JSONB containment) | +| `--history-limit N` | Max checkpoints 
per thread (default: all) | ## PostgreSQL Schema @@ -224,9 +249,9 @@ Remember to re-enable authentication after! The tool preserves `metadata.owner`, so each user will only see their own threads after migration. -### Rate Limiting +### Rate Limiting & Retry -Built-in delays (0.2-0.3s) prevent API overload. For large exports (1000+ threads), consider running during off-peak hours. +Built-in delays (0.2s) between API calls prevent overload. Failed API calls are retried up to 3 times with exponential backoff (1s, 2s, 4s + jitter). For large exports (1000+ threads), consider running during off-peak hours. ## Troubleshooting diff --git a/langgraph_export/client.py b/langgraph_export/client.py index c976796..f5cbb08 100644 --- a/langgraph_export/client.py +++ b/langgraph_export/client.py @@ -70,20 +70,84 @@ async def get_thread_history( self, thread_id: str, limit: int = 100, + before: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """ - Get full thread history (checkpoints). + Get thread history (checkpoints) — single page. Args: thread_id: Thread ID limit: Maximum number of checkpoints to return + before: Checkpoint cursor for pagination Returns: - List of checkpoint dictionaries + List of checkpoint dictionaries (most-recent-first) """ - history = await self._client.threads.get_history(thread_id, limit=limit) + kwargs: Dict[str, Any] = {"limit": limit} + if before: + kwargs["before"] = before + + history = await self._client.threads.get_history(thread_id, **kwargs) return list(history) if history else [] + async def get_all_history( + self, + thread_id: str, + limit: Optional[int] = None, + page_size: int = 100, + ) -> List[Dict[str, Any]]: + """ + Get complete thread history with automatic pagination. + + Paginates using the `before` cursor until all checkpoints + are retrieved or the optional limit is reached. 
+ + Args: + thread_id: Thread ID + limit: Max total checkpoints (None = all) + page_size: Checkpoints per API call + + Returns: + List of all checkpoint dictionaries (most-recent-first) + """ + all_history: List[Dict[str, Any]] = [] + before = None + + while True: + batch_size = page_size + if limit: + remaining = limit - len(all_history) + batch_size = min(page_size, remaining) + + batch = await self.get_thread_history( + thread_id, limit=batch_size, before=before, + ) + + if not batch: + break + + all_history.extend(batch) + + if limit and len(all_history) >= limit: + break + + if len(batch) < batch_size: + break + + # Cursor: use the last (oldest) checkpoint in this batch + last = batch[-1] + checkpoint_config = last.get("checkpoint", {}) + if not checkpoint_config: + # Fallback: build cursor from checkpoint_id + cp_id = last.get("checkpoint_id") + if not cp_id: + break + checkpoint_config = {"checkpoint_id": cp_id} + + before = checkpoint_config + + return all_history + async def create_thread( self, thread_id: str, diff --git a/langgraph_export/exporters/json_exporter.py b/langgraph_export/exporters/json_exporter.py index 6c998fe..b8bb255 100644 --- a/langgraph_export/exporters/json_exporter.py +++ b/langgraph_export/exporters/json_exporter.py @@ -1,79 +1,90 @@ """ JSON Exporter -Export threads to a JSON file. +Export threads to a JSON file with streaming writes. +Threads are written incrementally to avoid holding all data in memory. """ import json from datetime import datetime from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, IO, List, Optional from langgraph_export.exporters.base import BaseExporter, ExportStats, ThreadData +_JSON_INDENT = 2 + class JSONExporter(BaseExporter): """ - Export threads to JSON file. + Export threads to JSON file with streaming writes. - Creates a JSON file with all threads, metadata, and checkpoints. 
+ Writes each thread to disk as it's exported, rather than + buffering everything in memory. The output format is + backward-compatible with the original buffered exporter. """ def __init__( self, source_url: str, output_file: str = "threads_backup.json", - indent: int = 2, ): - """ - Initialize JSON exporter. - - Args: - source_url: Source LangGraph deployment URL - output_file: Path to output JSON file - indent: JSON indentation level (default: 2) - """ super().__init__(source_url) self.output_file = Path(output_file) - self.indent = indent - self._threads: List[Dict[str, Any]] = [] + self._file: Optional[IO[str]] = None + self._thread_count = 0 async def connect(self) -> None: - """No connection needed for JSON export.""" - pass + """Open the file and write the JSON header.""" + self.output_file.parent.mkdir(parents=True, exist_ok=True) + self._file = open(self.output_file, "w", encoding="utf-8") + self._thread_count = 0 + # Start the JSON object with threads array + self._file.write("{\n") + self._file.write(f' "threads": [\n') async def export_thread(self, thread: ThreadData) -> None: - """ - Add thread to export buffer. + """Write a single thread to the file immediately.""" + if not self._file: + raise RuntimeError("Exporter not connected. 
Call connect() first.") + if self._thread_count > 0: + self._file.write(",\n") + + thread_json = json.dumps( + thread.to_dict(), indent=_JSON_INDENT, ensure_ascii=False, default=str, + ) + # Indent each line of the thread JSON by 4 spaces (inside the array) + indented = "\n".join(f" {line}" for line in thread_json.splitlines()) + self._file.write(indented) - Args: - thread: Thread data to export - """ - self._threads.append(thread.to_dict()) + self._thread_count += 1 self.stats.threads_exported += 1 self.stats.checkpoints_exported += len(thread.history) async def finalize(self) -> None: - """Save all threads to JSON file.""" - export_data = { - "export_date": datetime.now().isoformat(), - "source": self.source_url, - "total_threads": len(self._threads), - "total_checkpoints": self.stats.checkpoints_exported, - "threads": self._threads, - } + """Close the threads array and write metadata footer.""" + if not self._file: + return - # Ensure parent directory exists - self.output_file.parent.mkdir(parents=True, exist_ok=True) + # Close the threads array + if self._thread_count > 0: + self._file.write("\n") + self._file.write(" ],\n") - # Write JSON file - self.output_file.write_text( - json.dumps(export_data, indent=self.indent, ensure_ascii=False, default=str) - ) + # Write metadata at the end (known only after all threads are exported) + self._file.write(f' "export_date": {json.dumps(datetime.now().isoformat())},\n') + self._file.write(f' "source": {json.dumps(self.source_url)},\n') + self._file.write(f' "total_threads": {self.stats.threads_exported},\n') + self._file.write(f' "total_checkpoints": {self.stats.checkpoints_exported}\n') + self._file.write("}\n") + + self._file.flush() async def close(self) -> None: - """Clear internal buffer.""" - self._threads.clear() + """Close the file handle.""" + if self._file: + self._file.close() + self._file = None def get_file_size_mb(self) -> float: """Get output file size in MB.""" @@ -83,35 +94,18 @@ def 
get_file_size_mb(self) -> float: @staticmethod def load_threads(file_path: str) -> List[ThreadData]: - """ - Load threads from JSON file. - - Args: - file_path: Path to JSON file - - Returns: - List of ThreadData objects - """ + """Load threads from JSON file.""" path = Path(file_path) if not path.exists(): raise FileNotFoundError(f"File not found: {file_path}") data = json.loads(path.read_text()) threads = data.get("threads", []) - return [ThreadData.from_dict(t) for t in threads] @staticmethod def get_export_info(file_path: str) -> Dict[str, Any]: - """ - Get export metadata from JSON file. - - Args: - file_path: Path to JSON file - - Returns: - Dictionary with export date, source, and counts - """ + """Get export metadata from JSON file.""" path = Path(file_path) if not path.exists(): raise FileNotFoundError(f"File not found: {file_path}") diff --git a/langgraph_export/migrator.py b/langgraph_export/migrator.py index cab04cd..d52000c 100644 --- a/langgraph_export/migrator.py +++ b/langgraph_export/migrator.py @@ -5,6 +5,8 @@ """ import asyncio +import logging +import random from typing import Any, Callable, Dict, List, Optional from langgraph_sdk.errors import ConflictError, APIStatusError @@ -78,6 +80,29 @@ async def close(self) -> None: for exporter in self._exporters: await exporter.close() + @staticmethod + async def _retry(coro_factory, max_attempts=3, base_delay=1.0, label=""): + """ + Retry an async call with exponential backoff + jitter. 
+ + Args: + coro_factory: Callable that returns a new coroutine each call + max_attempts: Maximum number of attempts + base_delay: Base delay in seconds (doubles each retry) + label: Label for log messages + """ + last_error = None + for attempt in range(max_attempts): + try: + return await coro_factory() + except Exception as e: + last_error = e + if attempt < max_attempts - 1: + delay = base_delay * (2 ** attempt) * (0.7 + random.random() * 0.6) + logging.warning(f"Retry {attempt + 1}/{max_attempts} for {label}: {e} (wait {delay:.1f}s)") + await asyncio.sleep(delay) + raise last_error + def add_json_exporter(self, output_file: str = "threads_backup.json") -> JSONExporter: """ Add JSON file exporter. @@ -109,6 +134,8 @@ def add_postgres_exporter(self, database_url: str) -> PostgresExporter: async def fetch_all_threads( self, limit: Optional[int] = None, + metadata_filter: Optional[Dict[str, Any]] = None, + history_limit: Optional[int] = None, progress_callback: Optional[Callable[[int, str], None]] = None, ) -> List[ThreadData]: """ @@ -116,6 +143,8 @@ async def fetch_all_threads( Args: limit: Maximum number of threads to fetch (None = all) + metadata_filter: Optional metadata dict for server-side filtering + history_limit: Max checkpoints per thread (None = all) progress_callback: Optional callback(count, message) for progress updates Returns: @@ -128,14 +157,15 @@ async def fetch_all_threads( offset = 0 batch_size = 100 if limit is None else min(limit, 100) - # Fetch thread list while True: if progress_callback: progress_callback(len(all_threads), f"Fetching threads (offset={offset})...") - threads = await self._source_client.search_threads( - limit=batch_size, - offset=offset, + threads = await self._retry( + lambda o=offset: self._source_client.search_threads( + limit=batch_size, offset=o, metadata=metadata_filter, + ), + label="search_threads", ) if not threads: @@ -147,9 +177,16 @@ async def fetch_all_threads( continue try: - # Get full thread details - 
details = await self._source_client.get_thread(thread_id) - history = await self._source_client.get_thread_history(thread_id) + details = await self._retry( + lambda tid=thread_id: self._source_client.get_thread(tid), + label=f"get_thread({thread_id[:8]})", + ) + history = await self._retry( + lambda tid=thread_id: self._source_client.get_all_history( + tid, limit=history_limit, + ), + label=f"get_all_history({thread_id[:8]})", + ) thread_data = ThreadData( thread_id=thread_id, @@ -169,18 +206,16 @@ async def fetch_all_threads( except Exception as e: if progress_callback: - progress_callback(len(all_threads), f"Error: {thread_id}: {e}") + progress_callback(len(all_threads), f"Skipped {thread_id[:8]}: {e}") continue await asyncio.sleep(self.rate_limit_delay) - # Check limit if limit and len(all_threads) >= limit: return all_threads offset += len(threads) - # No more pages if len(threads) < batch_size: break @@ -192,6 +227,8 @@ async def export_threads( self, threads: Optional[List[ThreadData]] = None, limit: Optional[int] = None, + metadata_filter: Optional[Dict[str, Any]] = None, + history_limit: Optional[int] = None, progress_callback: Optional[Callable[[int, str], None]] = None, ) -> ExportStats: """ @@ -200,14 +237,20 @@ async def export_threads( Args: threads: Pre-fetched threads (if None, fetches from source) limit: Maximum number of threads to export + metadata_filter: Optional metadata dict for server-side filtering + history_limit: Max checkpoints per thread (None = all) progress_callback: Optional callback for progress updates Returns: Combined export statistics """ - # Fetch threads if not provided if threads is None: - threads = await self.fetch_all_threads(limit, progress_callback) + threads = await self.fetch_all_threads( + limit=limit, + metadata_filter=metadata_filter, + history_limit=history_limit, + progress_callback=progress_callback, + ) # Connect all exporters for exporter in self._exporters: @@ -514,14 +557,10 @@ async def validate_migration( # 
Validate checkpoint history for a sample thread if check_history and sample_thread_id: if self._source_client: - src_history = await self._source_client.get_thread_history( - sample_thread_id, limit=1000 - ) + src_history = await self._source_client.get_all_history(sample_thread_id) result["history_source"] = len(src_history) if self._target_client: - tgt_history = await self._target_client.get_thread_history( - sample_thread_id, limit=1000 - ) + tgt_history = await self._target_client.get_all_history(sample_thread_id) result["history_target"] = len(tgt_history) return result diff --git a/migrate_threads.py b/migrate_threads.py index 6480307..db9b72d 100644 --- a/migrate_threads.py +++ b/migrate_threads.py @@ -13,8 +13,10 @@ import argparse import asyncio +import json import os import sys +from typing import Optional from dotenv import load_dotenv from rich.console import Console @@ -36,6 +38,8 @@ async def run_export_json( source_api_key: str, output_file: str, test_single: bool = False, + metadata_filter: Optional[dict] = None, + history_limit: Optional[int] = None, ) -> None: """Export threads to JSON file.""" console.print(Panel.fit( @@ -43,6 +47,9 @@ async def run_export_json( border_style="cyan", )) + if metadata_filter: + console.print(f"[cyan]Metadata filter:[/cyan] {metadata_filter}") + async with ThreadMigrator(source_url=source_url, source_api_key=source_api_key) as migrator: json_exporter = migrator.add_json_exporter(output_file) @@ -59,9 +66,13 @@ def update_progress(count: int, message: str) -> None: progress.update(task, completed=count, description=f"[cyan]{message}") limit = 1 if test_single else None - stats = await migrator.export_threads(limit=limit, progress_callback=update_progress) + stats = await migrator.export_threads( + limit=limit, + metadata_filter=metadata_filter, + history_limit=history_limit, + progress_callback=update_progress, + ) - # Display results console.print(f"\n[green]✓[/green] Threads exported: {stats.threads_exported}") 
console.print(f"[green]✓[/green] Checkpoints exported: {stats.checkpoints_exported}") console.print(f"[green]✓[/green] Output file: [bold]{output_file}[/bold]") @@ -74,6 +85,8 @@ async def run_export_postgres( database_url: str, output_file: str, test_single: bool = False, + metadata_filter: Optional[dict] = None, + history_limit: Optional[int] = None, ) -> None: """Export threads to PostgreSQL (and JSON backup).""" console.print(Panel.fit( @@ -81,6 +94,9 @@ async def run_export_postgres( border_style="cyan", )) + if metadata_filter: + console.print(f"[cyan]Metadata filter:[/cyan] {metadata_filter}") + async with ThreadMigrator(source_url=source_url, source_api_key=source_api_key) as migrator: json_exporter = migrator.add_json_exporter(output_file) pg_exporter = migrator.add_postgres_exporter(database_url) @@ -98,7 +114,12 @@ def update_progress(count: int, message: str) -> None: progress.update(task, completed=count, description=f"[cyan]{message}") limit = 1 if test_single else None - stats = await migrator.export_threads(limit=limit, progress_callback=update_progress) + stats = await migrator.export_threads( + limit=limit, + metadata_filter=metadata_filter, + history_limit=history_limit, + progress_callback=update_progress, + ) # Get PostgreSQL stats db_stats = await pg_exporter.get_database_stats() @@ -174,14 +195,18 @@ async def run_full_migration( backup_file: str, dry_run: bool = False, test_single: bool = False, + metadata_filter: Optional[dict] = None, + history_limit: Optional[int] = None, ) -> None: """Full migration: export + import + validate.""" - # Phase 1: Export console.print(Panel.fit( "[bold cyan]Phase 1: Export threads from source[/bold cyan]", border_style="cyan", )) + if metadata_filter: + console.print(f"[cyan]Metadata filter:[/cyan] {metadata_filter}") + async with ThreadMigrator( source_url=source_url, target_url=target_url, @@ -203,9 +228,15 @@ def update_progress(count: int, message: str) -> None: progress.update(task, completed=count, 
description=f"[cyan]{message}") limit = 1 if test_single else None - threads = await migrator.fetch_all_threads(limit=limit, progress_callback=update_progress) + threads = await migrator.fetch_all_threads( + limit=limit, + metadata_filter=metadata_filter, + history_limit=history_limit, + progress_callback=update_progress, + ) - # Export to JSON + # Export to JSON backup + await json_exporter.connect() for thread in threads: await json_exporter.export_thread(thread) await json_exporter.finalize() @@ -375,6 +406,10 @@ def main() -> None: parser.add_argument("--backup-file", default="threads_backup.json", help="Backup file path") parser.add_argument("--dry-run", action="store_true", help="Simulation mode") parser.add_argument("--test-single", action="store_true", help="Test with single thread") + parser.add_argument("--metadata-filter", metavar="JSON", + help="Filter threads by metadata (JSON string, e.g. '{\"workspace_id\": 4}')") + parser.add_argument("--history-limit", type=int, default=None, + help="Max checkpoints per thread (default: all)") args = parser.parse_args() @@ -392,6 +427,20 @@ def main() -> None: ) database_url = args.database_url or os.getenv("DATABASE_URL") + # Parse metadata filter + metadata_filter = None + if args.metadata_filter: + try: + metadata_filter = json.loads(args.metadata_filter) + if not isinstance(metadata_filter, dict): + console.print("[red]✗ --metadata-filter must be a JSON object[/red]") + sys.exit(1) + except json.JSONDecodeError as e: + console.print(f"[red]✗ Invalid JSON in --metadata-filter: {e}[/red]") + sys.exit(1) + + history_limit = args.history_limit + # Validate keys based on the command needs_source = bool(args.export_json or args.export_postgres or args.full or args.validate) needs_target = bool(args.import_json or args.full or args.validate) @@ -427,6 +476,8 @@ def main() -> None: source_api_key=source_api_key, output_file=args.export_json, test_single=args.test_single, + metadata_filter=metadata_filter, + 
history_limit=history_limit, )) elif args.export_postgres: @@ -442,6 +493,8 @@ def main() -> None: database_url=database_url, output_file=args.backup_file, test_single=args.test_single, + metadata_filter=metadata_filter, + history_limit=history_limit, )) elif args.import_json: @@ -467,6 +520,8 @@ def main() -> None: backup_file=args.backup_file, dry_run=args.dry_run, test_single=args.test_single, + metadata_filter=metadata_filter, + history_limit=history_limit, )) elif args.validate: From ac26f39274bfa4466dc14b88661b279bdf7efd60 Mon Sep 17 00:00:00 2001 From: MarioAlessandroNapoli Date: Wed, 18 Feb 2026 23:14:14 +0100 Subject: [PATCH 4/4] feat: concurrent export, pagination bugfix, streaming writes, legacy import fix Critical bugfix: - Fix history pagination cursor format: server expects {"configurable": {"checkpoint_id": ...}}, not flat {"checkpoint_id": ...}. Without this fix, every page after the first returns a 500 error, silently truncating exports to ~100 checkpoints per thread. Performance: - Concurrent thread fetching with asyncio.Semaphore (configurable --concurrency, default 5) - Streaming JSON export via producer-consumer queue (constant memory usage) - Per-page retry with exponential backoff + jitter (instead of retrying entire history) New features: - --legacy-terminal-node: specify graph terminal node for legacy imports, ensuring next=[] so threads are continuable after migration - --concurrency N: control parallel thread fetches - Rich progress bars with per-thread detail and elapsed time Fixes: - JSON loader uses strict=False to handle control characters in agent messages - Removed unused imports and dead code - Updated README with complete command reference and troubleshooting --- README.md | 205 ++++----- langgraph_export/client.py | 165 ++----- langgraph_export/exporters/json_exporter.py | 4 +- langgraph_export/migrator.py | 462 +++++++++----------- migrate_threads.py | 442 ++++++++----------- 5 files changed, 526 insertions(+), 752 deletions(-) 
diff --git a/README.md b/README.md index 47c7adc..d43a1e0 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ -# LangGraph Threads Export Tool +# LangGraph Threads Migration Tool -A Python tool to export threads, checkpoints, and conversation history from LangGraph Cloud deployments. Save your data to JSON files, PostgreSQL databases, or migrate directly to another deployment. +A Python tool to export, backup, and migrate threads with full checkpoint history between LangGraph Cloud deployments. ## Why This Tool? LangGraph Cloud stores your conversation threads and checkpoints, but there's no built-in way to: - **Backup your data** before deleting a deployment -- **Migrate conversations** between environments (prod → dev) +- **Migrate conversations** between environments (prod → staging, or across orgs) - **Store threads in your own database** for analytics or compliance - **Download conversation history** as JSON for processing @@ -14,39 +14,26 @@ This tool solves all of these problems. 
## Features -- **Export to JSON** - Download all threads and checkpoints as a backup file -- **Export to PostgreSQL** - Store threads in your own database with proper schema -- **Migrate between deployments** - Transfer threads from one LangGraph Cloud deployment to another -- **Preserve everything** - Thread IDs, metadata (including `owner` for multi-tenancy), checkpoints, and conversation history -- **Test mode** - Export/migrate a single thread first to verify everything works -- **Dry-run mode** - Preview changes without making any modifications -- **Progress tracking** - Real-time progress bars and detailed summaries -- **Metadata filtering** - Export only threads matching specific metadata (e.g., by workspace) -- **Retry with backoff** - Automatic retries on API failures (3 attempts with exponential backoff) -- **Streaming JSON** - Memory-efficient export that writes threads incrementally to disk -- **Full history pagination** - Fetches complete checkpoint history regardless of thread size - -## Use Cases - -| Scenario | Command | -|----------|---------| -| Backup before deleting deployment | `--export-json backup.json` | -| Cost optimization (expensive → cheaper deployment) | `--full` | -| Store in your own PostgreSQL | `--export-postgres` | -| Environment migration (staging → prod) | `--migrate` | -| Disaster recovery | `--import-json backup.json` | -| Export single workspace (multi-tenant) | `--metadata-filter '{"workspace_id": 4}'` | +- **Export to JSON** — Streaming writes, memory-efficient even for multi-GB exports +- **Export to PostgreSQL** — Store threads in your own database with proper schema +- **Migrate between deployments** — Full migration with supersteps (preserves checkpoint chains) and automatic legacy fallback +- **Concurrent fetching** — Configurable parallelism (default: 5 threads) for fast exports +- **Per-page retry** — Automatic retries with exponential backoff on paginated API calls +- **History pagination** — Correct cursor format 
(`{"configurable": {"checkpoint_id": ...}}`) for complete checkpoint retrieval +- **Legacy import fix** — `--legacy-terminal-node` sets `next=[]` on threads imported via fallback mode +- **Rich progress bars** — Real-time progress with thread counts, elapsed time, and per-thread details +- **Metadata filtering** — Export only threads matching specific metadata (e.g., by workspace) +- **Dry-run & test modes** — Preview changes or test with a single thread before full operations +- **Cross-org support** — Separate API keys for source and target deployments ## Installation ```bash -# Clone the repository -git clone https://github.com/YOUR_USERNAME/langgraph-threads-migration.git +git clone https://github.com/farouk09/langgraph-threads-migration.git cd langgraph-threads-migration # Using uv (recommended) -uv venv -source .venv/bin/activate +uv venv && source .venv/bin/activate uv pip install -r requirements.txt # Or using pip @@ -55,8 +42,6 @@ pip install -r requirements.txt ## Configuration -Create a `.env` file: - ```bash cp .env.example .env ``` @@ -79,9 +64,7 @@ DATABASE_URL=postgresql://user:password@localhost:5432/dbname ## Usage -### Export to JSON file - -Download all threads and checkpoints as a backup: +### Export to JSON ```bash python migrate_threads.py \ @@ -91,27 +74,19 @@ python migrate_threads.py \ ### Export to PostgreSQL -Store threads in your own database: - ```bash python migrate_threads.py \ --source-url https://my-deployment.langgraph.app \ --export-postgres ``` -This creates two tables: -- `langgraph_threads` - Thread metadata and current state -- `langgraph_checkpoints` - Full checkpoint history - -### Migrate between deployments - -Transfer all threads from one deployment to another: +### Full migration (export + import + validate) ```bash # Same org (shared API key) python migrate_threads.py \ - --source-url https://my-prod.langgraph.app \ - --target-url https://my-dev.langgraph.app \ + --source-url https://old-deploy.langgraph.app \ + 
--target-url https://new-deploy.langgraph.app \ --full # Cross-org (separate API keys) @@ -125,17 +100,24 @@ python migrate_threads.py \ ### Import from JSON -Restore threads from a backup file: - ```bash python migrate_threads.py \ --target-url https://my-deployment.langgraph.app \ --import-json threads_backup.json ``` -### Filter by metadata +### Import with legacy terminal node fix -Export only threads matching specific metadata (useful for multi-tenant deployments): +When importing threads that fall back to legacy mode (no supersteps), the thread's `next` field may point to the graph's entry node instead of being empty. Use `--legacy-terminal-node` to specify your graph's terminal node, which sets `next=[]` so threads are continuable: + +```bash +python migrate_threads.py \ + --target-url https://my-deployment.langgraph.app \ + --import-json backup.json \ + --legacy-terminal-node "MyLastNode.after_handler" +``` + +### Filter by metadata ```bash # Export threads for a specific workspace @@ -143,18 +125,10 @@ python migrate_threads.py \ --source-url https://my-deployment.langgraph.app \ --export-json workspace_4.json \ --metadata-filter '{"workspace_id": 4}' - -# Export threads owned by a specific user -python migrate_threads.py \ - --source-url https://my-deployment.langgraph.app \ - --export-json user_backup.json \ - --metadata-filter '{"owner": "user@example.com"}' ``` ### Test with a single thread first -Always recommended before a full operation: - ```bash python migrate_threads.py \ --source-url https://my-prod.langgraph.app \ @@ -168,27 +142,48 @@ python migrate_threads.py \ |----------|-------------| | `--source-url` | Source LangGraph Cloud deployment URL | | `--target-url` | Target LangGraph Cloud deployment URL | -| `--api-key` | Shared API key fallback (or `LANGSMITH_API_KEY`) | +| `--api-key` | Shared API key fallback (or `LANGSMITH_API_KEY` env var) | | `--source-api-key` | Source API key for cross-org (or `LANGSMITH_SOURCE_API_KEY`) | | 
`--target-api-key` | Target API key for cross-org (or `LANGSMITH_TARGET_API_KEY`) | -| `--database-url` | PostgreSQL URL (or set in `.env`) | +| `--database-url` | PostgreSQL URL (or `DATABASE_URL` env var) | | `--export-json FILE` | Export threads to JSON file | | `--export-postgres` | Export threads to PostgreSQL database | | `--import-json FILE` | Import threads from JSON file | -| `--migrate` | Migrate threads (export + import) | | `--full` | Full migration (export + import + validate) | | `--validate` | Compare source vs target thread counts | | `--dry-run` | Simulation mode (no changes made) | | `--test-single` | Process only one thread (for testing) | -| `--metadata-filter JSON` | Filter threads by metadata (JSONB containment) | +| `--metadata-filter JSON` | Filter threads by metadata (JSON object) | | `--history-limit N` | Max checkpoints per thread (default: all) | +| `--concurrency N` | Parallel thread fetches (default: 5) | +| `--legacy-terminal-node NODE` | Graph terminal node name for legacy imports (sets `next=[]`) | +| `--backup-file FILE` | Backup file path (default: `threads_backup.json`) | + +## Import Strategies + +The tool uses two import strategies, with automatic fallback: + +### 1. Supersteps (preferred) +Replays state changes via `threads.create(supersteps=...)`, preserving the full checkpoint chain. This enables time-travel operations on the target deployment. + +### 2. Legacy fallback +When supersteps fail (e.g., due to incompatible serialized objects in old checkpoints), the tool falls back to `create_thread()` + `update_thread_state()`. This preserves the final state but creates only a single checkpoint. + +**Known issue with legacy import**: Without `--legacy-terminal-node`, threads imported in legacy mode may have `next=['SomeMiddleware.before_handler']` instead of `next=[]`, making them non-continuable. The flag fixes this by telling LangGraph which node "last ran". 
+
+## Key Bug Fixes (vs upstream)
+
+### History pagination cursor format
+The LangGraph Cloud API expects the `before` cursor in the format `{"configurable": {"checkpoint_id": "..."}}`, not `{"checkpoint_id": "..."}`. The incorrect format caused silent 500 errors on every page after the first, resulting in incomplete exports. This fix is critical for any thread with more than 100 checkpoints.
+
+### JSON parsing of agent messages
+Agent messages may contain unescaped control characters (e.g., literal newlines inside JSON strings). The JSON loader now uses `strict=False` when loading backups to handle these correctly.
 
 ## PostgreSQL Schema
 
 When using `--export-postgres`, the tool creates:
 
 ```sql
---- Threads table
 CREATE TABLE langgraph_threads (
     id SERIAL PRIMARY KEY,
     thread_id VARCHAR(255) UNIQUE NOT NULL,
@@ -200,7 +195,6 @@ CREATE TABLE langgraph_threads (
     exported_at TIMESTAMP DEFAULT NOW()
 );
 
---- Checkpoints table
 CREATE TABLE langgraph_checkpoints (
     id SERIAL PRIMARY KEY,
     thread_id VARCHAR(255) REFERENCES langgraph_threads(thread_id),
@@ -213,73 +207,46 @@ );
 }
 ```
 
- -### Rate Limiting & Retry - -Built-in delays (0.2s) between API calls prevent overload. Failed API calls are retried up to 3 times with exponential backoff (1s, 2s, 4s + jitter). For large exports (1000+ threads), consider running during off-peak hours. - -## Troubleshooting - -| Error | Solution | -|-------|----------| -| `PermissionDeniedError` | Use Service Key (`lsv2_sk_...`), not Personal Token | -| `ConflictError (409)` | Thread already exists (automatically skipped) | -| `asyncpg not installed` | Run `pip install asyncpg` for PostgreSQL support | - ## Project Structure ``` langgraph-threads-migration/ -├── migrate_threads.py # CLI entry point -├── langgraph_export/ # Main package +├── migrate_threads.py # CLI entry point (Rich progress bars) +├── langgraph_export/ │ ├── __init__.py -│ ├── client.py # LangGraph SDK wrapper -│ ├── migrator.py # Thread migration orchestrator +│ ├── client.py # LangGraph SDK wrapper (per-page retry, cursor fix) +│ ├── migrator.py # Migration orchestrator (concurrent fetch, supersteps, legacy) │ ├── models.py # SQLAlchemy models for PostgreSQL -│ └── exporters/ # Export backends +│ └── exporters/ │ ├── base.py # Abstract base exporter -│ ├── json_exporter.py # JSON file export +│ ├── json_exporter.py # Streaming JSON export │ └── postgres_exporter.py # PostgreSQL export ├── requirements.txt ├── .env.example └── LICENSE ``` +## Troubleshooting + +| Error | Solution | +|-------|----------| +| `PermissionDeniedError` | Use Service Key (`lsv2_sk_...`), not Personal Token | +| `ConflictError (409)` | Thread already exists on target (automatically skipped) | +| `asyncpg not installed` | Run `pip install asyncpg` for PostgreSQL support | +| `KeyError: 'configurable'` | Server-side error from wrong pagination cursor — fixed in this fork | +| `next != []` after import | Use `--legacy-terminal-node` with your graph's terminal node | + +## Important Notes + +### Authentication +If your LangGraph deployment uses custom authentication, you 
may need to temporarily disable it during export. + +### Multi-tenancy +The tool preserves thread metadata including `owner`, so each user will only see their own threads after migration. + +### Rate Limiting +Built-in delays (0.1s) between API calls prevent overload. Failed calls are retried up to 3 times with exponential backoff + jitter. + ## Contributing Contributions are welcome! Please feel free to submit a Pull Request. @@ -287,7 +254,3 @@ Contributions are welcome! Please feel free to submit a Pull Request. ## License MIT License - see [LICENSE](LICENSE) file for details. - -## Acknowledgments - -Built for the [LangGraph](https://github.com/langchain-ai/langgraph) community. diff --git a/langgraph_export/client.py b/langgraph_export/client.py index f5cbb08..6266281 100644 --- a/langgraph_export/client.py +++ b/langgraph_export/client.py @@ -1,68 +1,60 @@ """ -LangGraph Cloud API Client - -Wrapper around the official LangGraph SDK for thread operations. +LangGraph Cloud API Client — with per-page retry and concurrency support. """ +import asyncio +import logging +import random from typing import Any, Dict, List, Optional from langgraph_sdk import get_client +logger = logging.getLogger(__name__) + class LangGraphClient: """Client for interacting with LangGraph Cloud API via the official SDK.""" def __init__(self, base_url: str, api_key: str): - """ - Initialize the LangGraph client. - - Args: - base_url: LangGraph Cloud deployment URL - api_key: LangSmith API key (Service Key recommended) - """ self.base_url = base_url.rstrip("/") self.api_key = api_key self._client = get_client(url=base_url, api_key=api_key) async def close(self) -> None: - """Close the HTTP client.""" - # SDK handles cleanup automatically pass + # ── Retry helper ────────────────────────────────────────────── + + @staticmethod + async def _retry(coro_factory, max_attempts=3, base_delay=0.8, label=""): + """Retry with exponential backoff + jitter. 
Returns result or raises.""" + last_error = None + for attempt in range(max_attempts): + try: + return await coro_factory() + except Exception as e: + last_error = e + if attempt < max_attempts - 1: + delay = base_delay * (2 ** attempt) * (0.7 + random.random() * 0.6) + logger.warning(f"Retry {attempt+1}/{max_attempts} {label}: {e} ({delay:.1f}s)") + await asyncio.sleep(delay) + raise last_error + + # ── Thread operations ───────────────────────────────────────── + async def search_threads( self, limit: int = 100, offset: int = 0, metadata: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: - """ - Search for threads. - - Args: - limit: Maximum number of threads to return - offset: Number of threads to skip - metadata: Optional metadata filter - - Returns: - List of thread dictionaries - """ kwargs = {"limit": limit, "offset": offset} if metadata: kwargs["metadata"] = metadata - threads = await self._client.threads.search(**kwargs) return list(threads) if threads else [] async def get_thread(self, thread_id: str) -> Dict[str, Any]: - """ - Get thread details. - - Args: - thread_id: Thread ID - - Returns: - Thread dictionary with metadata, values, etc. - """ thread = await self._client.threads.get(thread_id) return dict(thread) if thread else {} @@ -72,21 +64,9 @@ async def get_thread_history( limit: int = 100, before: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: - """ - Get thread history (checkpoints) — single page. 
- - Args: - thread_id: Thread ID - limit: Maximum number of checkpoints to return - before: Checkpoint cursor for pagination - - Returns: - List of checkpoint dictionaries (most-recent-first) - """ kwargs: Dict[str, Any] = {"limit": limit} if before: kwargs["before"] = before - history = await self._client.threads.get_history(thread_id, **kwargs) return list(history) if history else [] @@ -96,22 +76,10 @@ async def get_all_history( limit: Optional[int] = None, page_size: int = 100, ) -> List[Dict[str, Any]]: - """ - Get complete thread history with automatic pagination. - - Paginates using the `before` cursor until all checkpoints - are retrieved or the optional limit is reached. - - Args: - thread_id: Thread ID - limit: Max total checkpoints (None = all) - page_size: Checkpoints per API call - - Returns: - List of all checkpoint dictionaries (most-recent-first) - """ + """Get complete history with per-page retry.""" all_history: List[Dict[str, Any]] = [] before = None + tid_short = thread_id[:8] while True: batch_size = page_size @@ -119,8 +87,13 @@ async def get_all_history( remaining = limit - len(all_history) batch_size = min(page_size, remaining) - batch = await self.get_thread_history( - thread_id, limit=batch_size, before=before, + page_num = len(all_history) // page_size + 1 + + batch = await self._retry( + lambda bs=batch_size, b=before: self.get_thread_history( + thread_id, limit=bs, before=b, + ), + label=f"history({tid_short} p{page_num})", ) if not batch: @@ -130,39 +103,26 @@ async def get_all_history( if limit and len(all_history) >= limit: break - if len(batch) < batch_size: break - # Cursor: use the last (oldest) checkpoint in this batch + # Build cursor for next page — server expects {"configurable": {"checkpoint_id": ...}} last = batch[-1] - checkpoint_config = last.get("checkpoint", {}) - if not checkpoint_config: - # Fallback: build cursor from checkpoint_id - cp_id = last.get("checkpoint_id") - if not cp_id: - break - checkpoint_config = 
{"checkpoint_id": cp_id} - - before = checkpoint_config + cp_obj = last.get("checkpoint", {}) + cp_id = cp_obj.get("checkpoint_id") or last.get("checkpoint_id") + if not cp_id: + break + before = {"configurable": {"checkpoint_id": cp_id}} return all_history + # ── Write operations ────────────────────────────────────────── + async def create_thread( self, thread_id: str, metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - """ - Create a new thread with a specific ID. - - Args: - thread_id: Thread ID to use - metadata: Optional metadata to attach - - Returns: - Created thread dictionary - """ thread = await self._client.threads.create( thread_id=thread_id, metadata=metadata or {}, @@ -176,22 +136,6 @@ async def create_thread_with_history( supersteps: Optional[List[Dict[str, Any]]] = None, if_exists: Optional[str] = None, ) -> Dict[str, Any]: - """ - Create a new thread with pre-populated checkpoint history. - - Uses the supersteps parameter to replay state updates, - reconstructing the full checkpoint chain on the target. - - Args: - thread_id: Thread ID to use - metadata: Optional metadata to attach - supersteps: List of supersteps, each containing updates - with values and as_node - if_exists: Conflict behavior ('raise' or 'do_nothing') - - Returns: - Created thread dictionary - """ kwargs: Dict[str, Any] = { "thread_id": thread_id, "metadata": metadata or {}, @@ -200,7 +144,6 @@ async def create_thread_with_history( kwargs["supersteps"] = supersteps if if_exists: kwargs["if_exists"] = if_exists - thread = await self._client.threads.create(**kwargs) return dict(thread) if thread else {} @@ -210,33 +153,11 @@ async def update_thread_state( values: Dict[str, Any], as_node: Optional[str] = None, ) -> Dict[str, Any]: - """ - Update thread state. 
- - Args: - thread_id: Thread ID - values: State values to update - as_node: Optional node name for the update - - Returns: - Update result dictionary - """ result = await self._client.threads.update_state( - thread_id, - values=values, - as_node=as_node, + thread_id, values=values, as_node=as_node, ) return dict(result) if result else {} async def get_thread_state(self, thread_id: str) -> Dict[str, Any]: - """ - Get current thread state. - - Args: - thread_id: Thread ID - - Returns: - Current state dictionary - """ state = await self._client.threads.get_state(thread_id) return dict(state) if state else {} diff --git a/langgraph_export/exporters/json_exporter.py b/langgraph_export/exporters/json_exporter.py index b8bb255..2887075 100644 --- a/langgraph_export/exporters/json_exporter.py +++ b/langgraph_export/exporters/json_exporter.py @@ -99,7 +99,7 @@ def load_threads(file_path: str) -> List[ThreadData]: if not path.exists(): raise FileNotFoundError(f"File not found: {file_path}") - data = json.loads(path.read_text()) + data = json.loads(path.read_text(encoding="utf-8"), strict=False) threads = data.get("threads", []) return [ThreadData.from_dict(t) for t in threads] @@ -110,7 +110,7 @@ def get_export_info(file_path: str) -> Dict[str, Any]: if not path.exists(): raise FileNotFoundError(f"File not found: {file_path}") - data = json.loads(path.read_text()) + data = json.loads(path.read_text(encoding="utf-8"), strict=False) return { "export_date": data.get("export_date"), "source": data.get("source"), diff --git a/langgraph_export/migrator.py b/langgraph_export/migrator.py index d52000c..8bf0921 100644 --- a/langgraph_export/migrator.py +++ b/langgraph_export/migrator.py @@ -1,12 +1,12 @@ """ -Thread Migrator - -Main class for orchestrating thread export and migration operations. +Thread Migrator — concurrent fetch, streaming export, per-page retry. 
""" import asyncio import logging import random +import time +from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional from langgraph_sdk.errors import ConflictError, APIStatusError @@ -16,14 +16,34 @@ from langgraph_export.exporters.json_exporter import JSONExporter from langgraph_export.exporters.postgres_exporter import PostgresExporter +logger = logging.getLogger(__name__) -class ThreadMigrator: - """ - Thread migration and export manager. +CONCURRENCY = 5 # max parallel thread fetches + + +@dataclass +class MigrationProgress: + """Mutable progress state shared across concurrent tasks.""" + exported: int = 0 + skipped: int = 0 + failed: int = 0 + total_checkpoints: int = 0 + total_threads: int = 0 # set after discovery + start_time: float = field(default_factory=time.time) + errors: List[str] = field(default_factory=list) + + @property + def elapsed(self) -> float: + return time.time() - self.start_time - Handles fetching threads from source, exporting to various - destinations, and optionally importing to a target deployment. - """ + @property + def rate(self) -> float: + done = self.exported + self.skipped + self.failed + return done / self.elapsed if self.elapsed > 0 else 0 + + +class ThreadMigrator: + """Thread migration and export manager with concurrent fetching.""" def __init__( self, @@ -32,47 +52,34 @@ def __init__( api_key: str = "", source_api_key: Optional[str] = None, target_api_key: Optional[str] = None, - rate_limit_delay: float = 0.2, + rate_limit_delay: float = 0.1, + concurrency: int = CONCURRENCY, ): - """ - Initialize the migrator. 
- - Args: - source_url: Source LangGraph deployment URL - target_url: Target LangGraph deployment URL (for migration) - api_key: Shared API key (fallback for source/target) - source_api_key: API key for source deployment (overrides api_key) - target_api_key: API key for target deployment (overrides api_key) - rate_limit_delay: Delay between API calls (seconds) - """ self.source_url = source_url self.target_url = target_url self.source_api_key = source_api_key or api_key self.target_api_key = target_api_key or api_key self.rate_limit_delay = rate_limit_delay + self.concurrency = concurrency self._source_client: Optional[LangGraphClient] = None self._target_client: Optional[LangGraphClient] = None self._exporters: List[BaseExporter] = [] async def __aenter__(self) -> "ThreadMigrator": - """Async context manager entry.""" await self.connect() return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - """Async context manager exit.""" await self.close() async def connect(self) -> None: - """Initialize clients with their respective API keys.""" if self.source_url: self._source_client = LangGraphClient(self.source_url, self.source_api_key) if self.target_url: self._target_client = LangGraphClient(self.target_url, self.target_api_key) async def close(self) -> None: - """Close all clients and exporters.""" if self._source_client: await self._source_client.close() if self._target_client: @@ -81,16 +88,7 @@ async def close(self) -> None: await exporter.close() @staticmethod - async def _retry(coro_factory, max_attempts=3, base_delay=1.0, label=""): - """ - Retry an async call with exponential backoff + jitter. 
- - Args: - coro_factory: Callable that returns a new coroutine each call - max_attempts: Maximum number of attempts - base_delay: Base delay in seconds (doubles each retry) - label: Label for log messages - """ + async def _retry(coro_factory, max_attempts=3, base_delay=0.8, label=""): last_error = None for attempt in range(max_attempts): try: @@ -99,129 +97,126 @@ async def _retry(coro_factory, max_attempts=3, base_delay=1.0, label=""): last_error = e if attempt < max_attempts - 1: delay = base_delay * (2 ** attempt) * (0.7 + random.random() * 0.6) - logging.warning(f"Retry {attempt + 1}/{max_attempts} for {label}: {e} (wait {delay:.1f}s)") + logger.warning(f"Retry {attempt+1}/{max_attempts} {label}: {e} ({delay:.1f}s)") await asyncio.sleep(delay) raise last_error def add_json_exporter(self, output_file: str = "threads_backup.json") -> JSONExporter: - """ - Add JSON file exporter. - - Args: - output_file: Path to output JSON file - - Returns: - The created exporter - """ exporter = JSONExporter(self.source_url or "", output_file) self._exporters.append(exporter) return exporter def add_postgres_exporter(self, database_url: str) -> PostgresExporter: - """ - Add PostgreSQL exporter. - - Args: - database_url: PostgreSQL connection URL - - Returns: - The created exporter - """ exporter = PostgresExporter(self.source_url or "", database_url) self._exporters.append(exporter) return exporter - async def fetch_all_threads( + # ── Discovery: list all thread IDs ──────────────────────────── + + async def _discover_thread_ids( self, limit: Optional[int] = None, metadata_filter: Optional[Dict[str, Any]] = None, - history_limit: Optional[int] = None, - progress_callback: Optional[Callable[[int, str], None]] = None, - ) -> List[ThreadData]: - """ - Fetch all threads from source deployment. 
- - Args: - limit: Maximum number of threads to fetch (None = all) - metadata_filter: Optional metadata dict for server-side filtering - history_limit: Max checkpoints per thread (None = all) - progress_callback: Optional callback(count, message) for progress updates - - Returns: - List of ThreadData objects - """ + ) -> List[Dict[str, Any]]: + """Fetch all thread summaries (ID + metadata only, no history).""" if not self._source_client: raise RuntimeError("Source URL not configured") - all_threads: List[ThreadData] = [] + all_summaries: List[Dict[str, Any]] = [] offset = 0 batch_size = 100 if limit is None else min(limit, 100) while True: - if progress_callback: - progress_callback(len(all_threads), f"Fetching threads (offset={offset})...") - - threads = await self._retry( + batch = await self._retry( lambda o=offset: self._source_client.search_threads( limit=batch_size, offset=o, metadata=metadata_filter, ), label="search_threads", ) + if not batch: + break + + all_summaries.extend(batch) - if not threads: + if limit and len(all_summaries) >= limit: + all_summaries = all_summaries[:limit] + break + if len(batch) < batch_size: break - for thread_summary in threads: - thread_id = thread_summary.get("thread_id") - if not thread_id: - continue + offset += len(batch) + await asyncio.sleep(self.rate_limit_delay) - try: - details = await self._retry( - lambda tid=thread_id: self._source_client.get_thread(tid), - label=f"get_thread({thread_id[:8]})", - ) - history = await self._retry( - lambda tid=thread_id: self._source_client.get_all_history( - tid, limit=history_limit, - ), - label=f"get_all_history({thread_id[:8]})", - ) + return all_summaries - thread_data = ThreadData( - thread_id=thread_id, - metadata=details.get("metadata", {}), - values=details.get("values", {}), - history=history, - created_at=details.get("created_at"), - updated_at=details.get("updated_at"), - ) - all_threads.append(thread_data) + # ── Fetch single thread (details + history) 
─────────────────── - if progress_callback: - progress_callback( - len(all_threads), - f"Fetched thread {thread_id[:8]}... ({len(history)} checkpoints)" - ) + async def _fetch_thread( + self, + thread_id: str, + history_limit: Optional[int] = None, + ) -> ThreadData: + """Fetch one thread's details + full history.""" + details = await self._retry( + lambda: self._source_client.get_thread(thread_id), + label=f"get({thread_id[:8]})", + ) + history = await self._source_client.get_all_history( + thread_id, limit=history_limit, + ) + return ThreadData( + thread_id=thread_id, + metadata=details.get("metadata", {}), + values=details.get("values", {}), + history=history, + created_at=details.get("created_at"), + updated_at=details.get("updated_at"), + ) - except Exception as e: - if progress_callback: - progress_callback(len(all_threads), f"Skipped {thread_id[:8]}: {e}") - continue + # ── Fetch all (in-memory, for --full mode) ──────────────────── - await asyncio.sleep(self.rate_limit_delay) + async def fetch_all_threads( + self, + limit: Optional[int] = None, + metadata_filter: Optional[Dict[str, Any]] = None, + history_limit: Optional[int] = None, + progress_callback: Optional[Callable[[int, str], None]] = None, + ) -> List[ThreadData]: + """Fetch all threads concurrently.""" + summaries = await self._discover_thread_ids(limit, metadata_filter) - if limit and len(all_threads) >= limit: - return all_threads + if progress_callback: + progress_callback(0, f"Discovered {len(summaries)} threads, fetching...") - offset += len(threads) + sem = asyncio.Semaphore(self.concurrency) + results: List[Optional[ThreadData]] = [None] * len(summaries) + done_count = 0 - if len(threads) < batch_size: - break + async def fetch_one(idx: int, thread_id: str): + nonlocal done_count + async with sem: + try: + results[idx] = await self._fetch_thread(thread_id, history_limit) + done_count += 1 + if progress_callback: + cp = len(results[idx].history) if results[idx] else 0 + 
progress_callback(done_count, f"Fetched {thread_id[:8]}... ({cp} cp)") + except Exception as e: + done_count += 1 + logger.warning(f"Skip {thread_id[:8]}: {e}") + if progress_callback: + progress_callback(done_count, f"Skip {thread_id[:8]}: {e}") - await asyncio.sleep(self.rate_limit_delay) + tasks = [ + fetch_one(i, s.get("thread_id")) + for i, s in enumerate(summaries) + if s.get("thread_id") + ] + await asyncio.gather(*tasks) + + return [t for t in results if t is not None] - return all_threads + # ── Streaming export (fetch + write one at a time) ──────────── async def export_threads( self, @@ -233,136 +228,141 @@ async def export_threads( ) -> ExportStats: """ Export threads to all configured exporters. + Streams fetch+write concurrently to minimize memory. + """ + for exporter in self._exporters: + await exporter.connect() - Args: - threads: Pre-fetched threads (if None, fetches from source) - limit: Maximum number of threads to export - metadata_filter: Optional metadata dict for server-side filtering - history_limit: Max checkpoints per thread (None = all) - progress_callback: Optional callback for progress updates + if threads is not None: + return await self._export_prefetched(threads, progress_callback) - Returns: - Combined export statistics - """ - if threads is None: - threads = await self.fetch_all_threads( - limit=limit, - metadata_filter=metadata_filter, - history_limit=history_limit, - progress_callback=progress_callback, - ) + # Streaming: discover → concurrent fetch → sequential write + if not self._source_client: + raise RuntimeError("Source URL not configured") + + summaries = await self._discover_thread_ids(limit, metadata_filter) + total = len(summaries) + + if progress_callback: + progress_callback(0, f"Found {total} threads — exporting...") + + sem = asyncio.Semaphore(self.concurrency) + # Use a queue: fetchers produce, writer consumes + queue: asyncio.Queue[Optional[ThreadData]] = asyncio.Queue(maxsize=self.concurrency * 2) + progress = 
MigrationProgress(total_threads=total) + + async def fetch_worker(thread_id: str): + async with sem: + try: + td = await self._fetch_thread(thread_id, history_limit) + await queue.put(td) + except Exception as e: + progress.failed += 1 + progress.errors.append(f"{thread_id[:8]}: {e}") + logger.warning(f"Fetch failed {thread_id[:8]}: {e}") + + async def writer(): + while True: + td = await queue.get() + if td is None: # sentinel + break + for exporter in self._exporters: + await exporter.export_thread(td) + progress.exported += 1 + progress.total_checkpoints += len(td.history) + if progress_callback: + progress_callback( + progress.exported, + f"{td.thread_id[:8]} ({len(td.history)} cp) " + f"[{progress.exported}/{total}]" + ) + + # Start writer task + writer_task = asyncio.create_task(writer()) + + # Launch all fetchers + thread_ids = [s.get("thread_id") for s in summaries if s.get("thread_id")] + fetch_tasks = [asyncio.create_task(fetch_worker(tid)) for tid in thread_ids] + await asyncio.gather(*fetch_tasks) + + # Signal writer to stop + await queue.put(None) + await writer_task - # Connect all exporters for exporter in self._exporters: - await exporter.connect() + await exporter.finalize() + + return ExportStats( + threads_exported=progress.exported, + checkpoints_exported=progress.total_checkpoints, + ) - # Export each thread + async def _export_prefetched( + self, + threads: List[ThreadData], + progress_callback: Optional[Callable[[int, str], None]] = None, + ) -> ExportStats: total_checkpoints = 0 for i, thread in enumerate(threads): for exporter in self._exporters: await exporter.export_thread(thread) total_checkpoints += len(thread.history) - if progress_callback: progress_callback(i + 1, f"Exported {thread.thread_id[:8]}...") - # Finalize all exporters for exporter in self._exporters: await exporter.finalize() - # Return combined stats - stats = ExportStats( + return ExportStats( threads_exported=len(threads), checkpoints_exported=total_checkpoints, ) - 
return stats + + # ── Supersteps computation ──────────────────────────────────── @staticmethod def _compute_values_delta( prev_values: Dict[str, Any], curr_values: Dict[str, Any], ) -> Dict[str, Any]: - """ - Compute the delta between two checkpoint states. - - For messages (add_messages reducer): extracts only new messages by ID. - For lists of dicts with 'id' field: extracts only new items by ID. - For all other fields: includes only if changed. - """ delta: Dict[str, Any] = {} - for key, curr_val in curr_values.items(): prev_val = prev_values.get(key) - if curr_val == prev_val: continue if key == "messages": - # add_messages reducer: diff by message ID - prev_ids = { - m.get("id") for m in (prev_val or []) if isinstance(m, dict) - } - new_msgs = [ - m for m in (curr_val or []) - if isinstance(m, dict) and m.get("id") not in prev_ids - ] + prev_ids = {m.get("id") for m in (prev_val or []) if isinstance(m, dict)} + new_msgs = [m for m in (curr_val or []) if isinstance(m, dict) and m.get("id") not in prev_ids] if new_msgs: delta["messages"] = new_msgs - elif ( isinstance(curr_val, list) and curr_val and isinstance(curr_val[0], dict) and "id" in curr_val[0] ): - # List of dicts with ID → dedup-aware delta - prev_ids = { - item.get("id") - for item in (prev_val or []) - if isinstance(item, dict) - } - new_items = [ - item for item in curr_val - if isinstance(item, dict) and item.get("id") not in prev_ids - ] + prev_ids = {item.get("id") for item in (prev_val or []) if isinstance(item, dict)} + new_items = [item for item in curr_val if isinstance(item, dict) and item.get("id") not in prev_ids] if new_items: delta[key] = new_items - else: - # Scalar or non-ID list → include if changed delta[key] = curr_val - return delta @staticmethod - def _history_to_supersteps( - history: List[Dict[str, Any]], - ) -> List[Dict[str, Any]]: - """ - Convert checkpoint history to supersteps format for threads.create(). - - History from get_history() is most-recent-first. 
- Supersteps must be chronological (oldest-first). - - Strategy: - - If metadata.writes is available (local checkpointer), use it directly. - - Otherwise (Cloud API), compute deltas between consecutive states. - Only state-changing steps produce supersteps; middleware no-ops are skipped. - """ + def _history_to_supersteps(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]: if not history: return [] chronological = list(reversed(history)) supersteps: List[Dict[str, Any]] = [] - # Check if writes data is available (first non-empty checkpoint) has_writes = any( - isinstance((h.get("metadata") or {}).get("writes"), dict) - for h in history + isinstance((h.get("metadata") or {}).get("writes"), dict) for h in history ) if has_writes: - # Direct approach: use metadata.writes for state in chronological: metadata = state.get("metadata") or {} writes = metadata.get("writes") @@ -376,50 +376,32 @@ def _history_to_supersteps( if updates: supersteps.append({"updates": updates}) else: - # Delta approach: compute diffs between consecutive states prev_values: Dict[str, Any] = {} for i, state in enumerate(chronological): curr_values = state.get("values") or {} - if i == 0: - # Include initial state as first superstep if curr_values: supersteps.append({ - "updates": [{ - "values": curr_values, - "as_node": "__start__", - }], + "updates": [{"values": curr_values, "as_node": "__start__"}], }) prev_values = curr_values continue - # as_node = what the previous checkpoint's next field says ran prev_next = chronological[i - 1].get("next", []) as_node = prev_next[0] if prev_next else "__unknown__" - delta = ThreadMigrator._compute_values_delta(prev_values, curr_values) - if delta: supersteps.append({ "updates": [{"values": delta, "as_node": as_node}], }) - prev_values = curr_values return supersteps - async def _import_thread_with_history( - self, - thread: ThreadData, - ) -> int: - """ - Import a single thread using supersteps to preserve checkpoint history. 
+ # ── Import ──────────────────────────────────────────────────── - Returns the number of checkpoints imported. - Raises on failure so caller can handle fallback. - """ + async def _import_thread_with_history(self, thread: ThreadData) -> int: supersteps = self._history_to_supersteps(thread.history) - await self._target_client.create_thread_with_history( thread_id=thread.thread_id, metadata=thread.metadata, @@ -427,8 +409,9 @@ async def _import_thread_with_history( ) return len(supersteps) - async def _import_thread_legacy(self, thread: ThreadData) -> None: - """Fallback: create thread + update_state (no history preservation).""" + async def _import_thread_legacy( + self, thread: ThreadData, as_node: Optional[str] = None, + ) -> None: await self._target_client.create_thread( thread_id=thread.thread_id, metadata=thread.metadata, @@ -437,29 +420,16 @@ async def _import_thread_legacy(self, thread: ThreadData) -> None: await self._target_client.update_thread_state( thread_id=thread.thread_id, values=thread.values, + as_node=as_node, ) async def import_threads( self, threads: List[ThreadData], dry_run: bool = False, + legacy_terminal_node: Optional[str] = None, progress_callback: Optional[Callable[[int, str], None]] = None, ) -> Dict[str, int]: - """ - Import threads to target deployment. - - Tries supersteps-based import first to preserve checkpoint history. - Falls back to legacy create+update_state if the target API - doesn't support supersteps. - - Args: - threads: Threads to import - dry_run: If True, don't actually create threads - progress_callback: Optional callback for progress updates - - Returns: - Dictionary with created, skipped, failed, checkpoints counts - """ if not self._target_client: raise RuntimeError("Target URL not configured") @@ -474,8 +444,7 @@ async def import_threads( if progress_callback: progress_callback( i + 1, - f"[DRY-RUN] Would create {thread.thread_id[:8]}... " - f"({len(supersteps)} supersteps)" + f"[DRY-RUN] {thread.thread_id[:8]}... 
({len(supersteps)} supersteps)" ) continue @@ -484,66 +453,52 @@ async def import_threads( checkpoints = await self._import_thread_with_history(thread) results["checkpoints"] += checkpoints except APIStatusError as e: - # supersteps not supported — switch to legacy for all remaining if e.status_code in (400, 422): use_legacy = True if progress_callback: - progress_callback( - i + 1, - "supersteps not supported, falling back to legacy import" - ) - await self._import_thread_legacy(thread) + progress_callback(i + 1, "supersteps unsupported → legacy mode") + await self._import_thread_legacy(thread, as_node=legacy_terminal_node) else: raise else: - await self._import_thread_legacy(thread) + await self._import_thread_legacy(thread, as_node=legacy_terminal_node) results["created"] += 1 if progress_callback: cp_count = len(thread.history) - mode = "legacy" if use_legacy else f"{cp_count} checkpoints" + mode = "legacy" if use_legacy else f"{cp_count} cp" progress_callback(i + 1, f"Created {thread.thread_id[:8]}... ({mode})") except ConflictError: results["skipped"] += 1 if progress_callback: - progress_callback(i + 1, f"Skipped (exists) {thread.thread_id[:8]}...") - + progress_callback(i + 1, f"Skip (exists) {thread.thread_id[:8]}...") except APIStatusError as e: results["failed"] += 1 if progress_callback: - progress_callback(i + 1, f"Failed {thread.thread_id[:8]}: {e}") - + progress_callback(i + 1, f"FAIL {thread.thread_id[:8]}: {e}") except Exception as e: results["failed"] += 1 if progress_callback: - progress_callback(i + 1, f"Error {thread.thread_id[:8]}: {e}") + progress_callback(i + 1, f"ERROR {thread.thread_id[:8]}: {e}") await asyncio.sleep(self.rate_limit_delay) return results + # ── Validation ──────────────────────────────────────────────── + async def validate_migration( self, check_history: bool = False, sample_thread_id: Optional[str] = None, ) -> Dict[str, Any]: - """ - Compare thread counts between source and target. 
- - Optionally validates checkpoint history for a sample thread. - - Returns: - Dictionary with source_count, target_count, and optionally - history_source/history_target for the sample thread. - """ source_count = 0 target_count = 0 if self._source_client: threads = await self._source_client.search_threads(limit=10000) source_count = len(threads) - if self._target_client: threads = await self._target_client.search_threads(limit=10000) target_count = len(threads) @@ -554,7 +509,6 @@ async def validate_migration( "difference": source_count - target_count, } - # Validate checkpoint history for a sample thread if check_history and sample_thread_id: if self._source_client: src_history = await self._source_client.get_all_history(sample_thread_id) diff --git a/migrate_threads.py b/migrate_threads.py index db9b72d..23d9cca 100644 --- a/migrate_threads.py +++ b/migrate_threads.py @@ -1,14 +1,9 @@ #!/usr/bin/env python3 """ -LangGraph Threads Export Tool +LangGraph Threads Migration Tool -Export threads, checkpoints, and conversation history from LangGraph Cloud. -Save to JSON files, PostgreSQL databases, or migrate to another deployment. - -Usage: - python migrate_threads.py --source-url --export-json threads.json - python migrate_threads.py --source-url --export-postgres - python migrate_threads.py --source-url --target-url --full +Export/import threads with full checkpoint history between LangGraph Cloud deployments. +Features: concurrent fetching, streaming JSON export, per-page retry, rich progress. 
""" import argparse @@ -21,18 +16,37 @@ from dotenv import load_dotenv from rich.console import Console from rich.panel import Panel -from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn +from rich.progress import ( + Progress, SpinnerColumn, TextColumn, BarColumn, + TimeElapsedColumn, MofNCompleteColumn, +) from rich.table import Table from langgraph_export import ThreadMigrator from langgraph_export.exporters import JSONExporter -# Load environment variables load_dotenv() - console = Console() +def make_progress(**kwargs) -> Progress: + """Create a rich Progress bar with consistent columns.""" + return Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(bar_width=30), + MofNCompleteColumn(), + TextColumn("•"), + TimeElapsedColumn(), + TextColumn("•"), + TextColumn("{task.fields[detail]}"), + console=console, + **kwargs, + ) + + +# ── Export to JSON ──────────────────────────────────────────────── + async def run_export_json( source_url: str, source_api_key: str, @@ -41,44 +55,63 @@ async def run_export_json( metadata_filter: Optional[dict] = None, history_limit: Optional[int] = None, ) -> None: - """Export threads to JSON file.""" console.print(Panel.fit( - "[bold cyan]Exporting threads to JSON[/bold cyan]", + "[bold cyan]Export threads to JSON[/bold cyan]", border_style="cyan", )) - if metadata_filter: - console.print(f"[cyan]Metadata filter:[/cyan] {metadata_filter}") + console.print(f" [dim]Filter:[/dim] {metadata_filter}") + console.print(f" [dim]Output:[/dim] {output_file}") + console.print(f" [dim]Concurrency:[/dim] 5 parallel fetches") + console.print() async with ThreadMigrator(source_url=source_url, source_api_key=source_api_key) as migrator: json_exporter = migrator.add_json_exporter(output_file) - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - console=console, - ) as progress: - 
task = progress.add_task("[cyan]Exporting...", total=None) - - def update_progress(count: int, message: str) -> None: - progress.update(task, completed=count, description=f"[cyan]{message}") + with make_progress() as progress: + # Phase 1: discover + discover_task = progress.add_task( + "[cyan]Discovering threads...", total=None, detail="searching..." + ) limit = 1 if test_single else None + summaries = await migrator._discover_thread_ids(limit, metadata_filter) + total = len(summaries) + progress.update(discover_task, completed=total, total=total, detail=f"{total} found") + + # Phase 2: fetch + export (streaming) + export_task = progress.add_task( + "[green]Fetching & writing", total=total, detail="starting..." + ) + + exported = 0 + + def on_progress(count: int, message: str) -> None: + nonlocal exported + exported = count + progress.update(export_task, completed=count, detail=message) + stats = await migrator.export_threads( limit=limit, metadata_filter=metadata_filter, history_limit=history_limit, - progress_callback=update_progress, + progress_callback=on_progress, ) - console.print(f"\n[green]✓[/green] Threads exported: {stats.threads_exported}") - console.print(f"[green]✓[/green] Checkpoints exported: {stats.checkpoints_exported}") - console.print(f"[green]✓[/green] Output file: [bold]{output_file}[/bold]") - console.print(f"[green]✓[/green] File size: {json_exporter.get_file_size_mb():.2f} MB") + # Summary + console.print() + table = Table(title="Export Complete", show_header=False, border_style="green") + table.add_column("Metric", style="cyan") + table.add_column("Value", justify="right", style="bold") + table.add_row("Threads", str(stats.threads_exported)) + table.add_row("Checkpoints", str(stats.checkpoints_exported)) + table.add_row("File", output_file) + table.add_row("Size", f"{json_exporter.get_file_size_mb():.2f} MB") + console.print(table) +# ── Export to PostgreSQL ────────────────────────────────────────── + async def run_export_postgres( 
source_url: str, source_api_key: str, @@ -88,104 +121,85 @@ async def run_export_postgres( metadata_filter: Optional[dict] = None, history_limit: Optional[int] = None, ) -> None: - """Export threads to PostgreSQL (and JSON backup).""" console.print(Panel.fit( - "[bold cyan]Exporting threads to PostgreSQL[/bold cyan]", + "[bold cyan]Export threads to PostgreSQL + JSON[/bold cyan]", border_style="cyan", )) - if metadata_filter: - console.print(f"[cyan]Metadata filter:[/cyan] {metadata_filter}") - async with ThreadMigrator(source_url=source_url, source_api_key=source_api_key) as migrator: json_exporter = migrator.add_json_exporter(output_file) pg_exporter = migrator.add_postgres_exporter(database_url) - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - console=console, - ) as progress: - task = progress.add_task("[cyan]Exporting...", total=None) + with make_progress() as progress: + task = progress.add_task("[cyan]Exporting...", total=None, detail="starting...") - def update_progress(count: int, message: str) -> None: - progress.update(task, completed=count, description=f"[cyan]{message}") + def on_progress(count: int, message: str) -> None: + progress.update(task, completed=count, detail=message) limit = 1 if test_single else None stats = await migrator.export_threads( limit=limit, metadata_filter=metadata_filter, history_limit=history_limit, - progress_callback=update_progress, + progress_callback=on_progress, ) - # Get PostgreSQL stats db_stats = await pg_exporter.get_database_stats() - - # Display results - console.print(f"\n[green]✓[/green] Threads exported: {stats.threads_exported}") - console.print(f"[green]✓[/green] Checkpoints exported: {stats.checkpoints_exported}") - console.print(f"[green]✓[/green] JSON backup: [bold]{output_file}[/bold] ({json_exporter.get_file_size_mb():.2f} MB)") + console.print(f"\n[green]✓[/green] Threads: {stats.threads_exported}") + 
console.print(f"[green]✓[/green] JSON: {output_file} ({json_exporter.get_file_size_mb():.2f} MB)") console.print(f"[green]✓[/green] PostgreSQL: {db_stats['threads']} threads, {db_stats['checkpoints']} checkpoints") +# ── Import from JSON ────────────────────────────────────────────── + async def run_import_json( target_url: str, target_api_key: str, input_file: str, dry_run: bool = False, + legacy_terminal_node: Optional[str] = None, ) -> None: - """Import threads from JSON file.""" console.print(Panel.fit( - "[bold green]Importing threads from JSON[/bold green]", + "[bold green]Import threads from JSON[/bold green]", border_style="green", )) - # Load threads from JSON threads = JSONExporter.load_threads(input_file) info = JSONExporter.get_export_info(input_file) + total = len(threads) - console.print(f"[cyan]Source:[/cyan] {info['source']}") - console.print(f"[cyan]Export date:[/cyan] {info['export_date']}") - console.print(f"[cyan]Threads to import:[/cyan] {len(threads)}") - + console.print(f" [dim]Source:[/dim] {info['source']}") + console.print(f" [dim]Date:[/dim] {info['export_date']}") + console.print(f" [dim]Threads:[/dim] {total}") if dry_run: - console.print("\n[yellow]⚠ DRY-RUN MODE: No changes will be made[/yellow]") + console.print(" [yellow]DRY-RUN — no changes[/yellow]") + console.print() async with ThreadMigrator(target_url=target_url, target_api_key=target_api_key) as migrator: - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - console=console, - ) as progress: - task = progress.add_task("[green]Importing...", total=len(threads)) - - def update_progress(count: int, message: str) -> None: - progress.update(task, completed=count, description=f"[green]{message}") + with make_progress() as progress: + task = progress.add_task("[green]Importing...", total=total, detail="starting...") + + def on_progress(count: int, message: str) -> None: + progress.update(task, 
completed=count, detail=message) results = await migrator.import_threads( - threads, - dry_run=dry_run, - progress_callback=update_progress, + threads, dry_run=dry_run, + legacy_terminal_node=legacy_terminal_node, + progress_callback=on_progress, ) - # Display results - table = Table(title="Import Summary") - table.add_column("Status", style="cyan") - table.add_column("Count", justify="right", style="magenta") - - table.add_row("Created", str(results["created"])) - table.add_row("Skipped (exists)", str(results["skipped"])) - table.add_row("Failed", str(results["failed"])) - table.add_row("Checkpoints", str(results.get("checkpoints", 0))) - table.add_row("Total threads", str(len(threads))) + console.print() + table = Table(title="Import Summary", border_style="green") + table.add_column("Status", style="cyan") + table.add_column("Count", justify="right", style="bold") + table.add_row("Created", str(results["created"])) + table.add_row("Skipped (exists)", str(results["skipped"])) + table.add_row("Failed", str(results["failed"])) + table.add_row("Checkpoints (supersteps)", str(results.get("checkpoints", 0))) + console.print(table) - console.print(table) +# ── Full migration ──────────────────────────────────────────────── async def run_full_migration( source_url: str, @@ -197,15 +211,15 @@ async def run_full_migration( test_single: bool = False, metadata_filter: Optional[dict] = None, history_limit: Optional[int] = None, + legacy_terminal_node: Optional[str] = None, ) -> None: - """Full migration: export + import + validate.""" + # Phase 1: Export console.print(Panel.fit( - "[bold cyan]Phase 1: Export threads from source[/bold cyan]", + "[bold cyan]Phase 1 — Export from source[/bold cyan]", border_style="cyan", )) - if metadata_filter: - console.print(f"[cyan]Metadata filter:[/cyan] {metadata_filter}") + console.print(f" [dim]Filter:[/dim] {metadata_filter}") async with ThreadMigrator( source_url=source_url, @@ -215,118 +229,82 @@ async def run_full_migration( ) as 
migrator: json_exporter = migrator.add_json_exporter(backup_file) - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - console=console, - ) as progress: - task = progress.add_task("[cyan]Exporting...", total=None) - - def update_progress(count: int, message: str) -> None: - progress.update(task, completed=count, description=f"[cyan]{message}") + with make_progress() as progress: + task = progress.add_task("[cyan]Exporting...", total=None, detail="discovering...") limit = 1 if test_single else None threads = await migrator.fetch_all_threads( limit=limit, metadata_filter=metadata_filter, history_limit=history_limit, - progress_callback=update_progress, + progress_callback=lambda c, m: progress.update(task, completed=c, detail=m), ) - # Export to JSON backup await json_exporter.connect() for thread in threads: await json_exporter.export_thread(thread) await json_exporter.finalize() - console.print(f"\n[green]✓[/green] Exported {len(threads)} threads") - console.print(f"[green]✓[/green] Backup: {backup_file} ({json_exporter.get_file_size_mb():.2f} MB)") + console.print(f" [green]✓[/green] {len(threads)} threads → {backup_file} ({json_exporter.get_file_size_mb():.2f} MB)") # Phase 2: Import console.print(Panel.fit( - f"[bold green]Phase 2: Import {len(threads)} threads to target[/bold green]", + f"[bold green]Phase 2 — Import {len(threads)} threads to target[/bold green]", border_style="green", )) - if dry_run: - console.print("[yellow]⚠ DRY-RUN MODE: No changes will be made[/yellow]") + console.print(" [yellow]DRY-RUN — no changes[/yellow]") - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - console=console, - ) as progress: - task = progress.add_task("[green]Importing...", total=len(threads)) - - def update_progress2(count: int, message: str) -> None: - progress.update(task, completed=count, 
description=f"[green]{message}") + with make_progress() as progress: + task = progress.add_task("[green]Importing...", total=len(threads), detail="starting...") results = await migrator.import_threads( threads, dry_run=dry_run, - progress_callback=update_progress2, + legacy_terminal_node=legacy_terminal_node, + progress_callback=lambda c, m: progress.update(task, completed=c, detail=m), ) - # Display import results - table = Table(title="Import Summary") + table = Table(title="Import Summary", border_style="green") table.add_column("Status", style="cyan") - table.add_column("Count", justify="right", style="magenta") - + table.add_column("Count", justify="right", style="bold") table.add_row("Created", str(results["created"])) - table.add_row("Skipped (exists)", str(results["skipped"])) + table.add_row("Skipped", str(results["skipped"])) table.add_row("Failed", str(results["failed"])) table.add_row("Checkpoints", str(results.get("checkpoints", 0))) - console.print(table) # Phase 3: Validate console.print(Panel.fit( - "[bold blue]Phase 3: Validate migration[/bold blue]", + "[bold blue]Phase 3 — Validate[/bold blue]", border_style="blue", )) - - # For test-single, also validate checkpoint history sample_id = threads[0].thread_id if test_single and threads else None validation = await migrator.validate_migration( check_history=bool(sample_id), sample_thread_id=sample_id, ) - table = Table(title="Validation") + table = Table(title="Validation", border_style="blue") table.add_column("Metric", style="cyan") - table.add_column("Source", justify="right", style="magenta") - table.add_column("Target", justify="right", style="magenta") - - table.add_row( - "Threads", - str(validation["source_count"]), - str(validation["target_count"]), - ) + table.add_column("Source", justify="right") + table.add_column("Target", justify="right") + table.add_row("Threads", str(validation["source_count"]), str(validation["target_count"])) if "history_source" in validation: table.add_row( - 
f"Checkpoints ({sample_id[:8]}...)", + f"History ({sample_id[:8]}...)", str(validation["history_source"]), str(validation["history_target"]), ) - console.print(table) if validation["difference"] <= 0: - console.print("[green]✓ Thread count validated![/green]") + console.print("[green]✓ Thread count validated[/green]") else: - console.print(f"[yellow]⚠ Target has {validation['difference']} fewer threads[/yellow]") + console.print(f"[yellow]⚠ {validation['difference']} threads missing on target[/yellow]") - if "history_source" in validation: - if validation["history_source"] == validation["history_target"]: - console.print("[green]✓ Checkpoint history validated![/green]") - else: - diff = validation["history_source"] - validation["history_target"] - console.print(f"[yellow]⚠ History mismatch: {diff} checkpoints missing[/yellow]") +# ── Validate only ───────────────────────────────────────────────── async def run_validate( source_url: str, @@ -334,11 +312,7 @@ async def run_validate( source_api_key: str, target_api_key: str, ) -> None: - """Validate migration by comparing thread counts.""" - console.print(Panel.fit( - "[bold blue]Validating migration[/bold blue]", - border_style="blue", - )) + console.print(Panel.fit("[bold blue]Validate migration[/bold blue]", border_style="blue")) async with ThreadMigrator( source_url=source_url, @@ -348,40 +322,34 @@ async def run_validate( ) as migrator: validation = await migrator.validate_migration() - table = Table(title="Source vs Target") - table.add_column("Deployment", style="cyan") - table.add_column("Threads", justify="right", style="magenta") - table.add_column("Status", style="green") - - table.add_row("Source", str(validation["source_count"]), "✓") - status = "✓" if validation["difference"] <= 0 else "⚠" - table.add_row("Target", str(validation["target_count"]), status) + table = Table(title="Source vs Target", border_style="blue") + table.add_column("Deployment", style="cyan") + table.add_column("Threads", 
justify="right", style="bold") + table.add_row("Source", str(validation["source_count"])) + table.add_row("Target", str(validation["target_count"])) + console.print(table) - console.print(table) + if validation["difference"] <= 0: + console.print("[green]✓ Migration validated[/green]") + else: + console.print(f"[yellow]⚠ {validation['difference']} threads missing on target[/yellow]") - if validation["difference"] <= 0: - console.print("[green]✓ Migration validated successfully![/green]") - else: - console.print(f"[yellow]⚠ Target has {validation['difference']} fewer threads[/yellow]") +# ── CLI ─────────────────────────────────────────────────────────── def main() -> None: - """Main entry point.""" parser = argparse.ArgumentParser( - description="LangGraph Threads Export Tool", + description="LangGraph Threads Migration Tool", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Export to JSON python migrate_threads.py --source-url https://my.langgraph.app --export-json backup.json - # Export to PostgreSQL - python migrate_threads.py --source-url https://my.langgraph.app --export-postgres - # Full migration (same org) - python migrate_threads.py --source-url https://prod.langgraph.app --target-url https://dev.langgraph.app --full + python migrate_threads.py --source-url https://src.langgraph.app --target-url https://tgt.langgraph.app --full - # Full migration (cross-org, separate API keys) + # Full migration (cross-org) python migrate_threads.py --source-url https://org1.langgraph.app --target-url https://org2.langgraph.app \\ --source-api-key lsv2_sk_... --target-api-key lsv2_sk_... 
--full @@ -392,10 +360,10 @@ def main() -> None: parser.add_argument("--source-url", help="Source LangGraph deployment URL") parser.add_argument("--target-url", help="Target LangGraph deployment URL") - parser.add_argument("--api-key", help="Shared API key fallback (or set LANGSMITH_API_KEY)") - parser.add_argument("--source-api-key", help="Source API key (or set LANGSMITH_SOURCE_API_KEY)") - parser.add_argument("--target-api-key", help="Target API key (or set LANGSMITH_TARGET_API_KEY)") - parser.add_argument("--database-url", help="PostgreSQL URL (or set DATABASE_URL)") + parser.add_argument("--api-key", help="Shared API key (or LANGSMITH_API_KEY)") + parser.add_argument("--source-api-key", help="Source API key (or LANGSMITH_SOURCE_API_KEY)") + parser.add_argument("--target-api-key", help="Target API key (or LANGSMITH_TARGET_API_KEY)") + parser.add_argument("--database-url", help="PostgreSQL URL (or DATABASE_URL)") parser.add_argument("--export-json", metavar="FILE", help="Export to JSON file") parser.add_argument("--export-postgres", action="store_true", help="Export to PostgreSQL") @@ -407,24 +375,20 @@ def main() -> None: parser.add_argument("--dry-run", action="store_true", help="Simulation mode") parser.add_argument("--test-single", action="store_true", help="Test with single thread") parser.add_argument("--metadata-filter", metavar="JSON", - help="Filter threads by metadata (JSON string, e.g. '{\"workspace_id\": 4}')") + help="Filter threads by metadata (JSON, e.g. 
'{\"workspace_id\": 4}')") parser.add_argument("--history-limit", type=int, default=None, help="Max checkpoints per thread (default: all)") + parser.add_argument("--concurrency", type=int, default=5, + help="Parallel thread fetches (default: 5)") + parser.add_argument("--legacy-terminal-node", metavar="NODE", + help="Graph terminal node name for legacy imports (sets next=[] via as_node)") args = parser.parse_args() - # Resolve API keys: CLI flag > specific env var > shared CLI flag > shared env var + # Resolve API keys shared_key = args.api_key or os.getenv("LANGSMITH_API_KEY") - source_api_key = ( - args.source_api_key - or os.getenv("LANGSMITH_SOURCE_API_KEY") - or shared_key - ) - target_api_key = ( - args.target_api_key - or os.getenv("LANGSMITH_TARGET_API_KEY") - or shared_key - ) + source_api_key = args.source_api_key or os.getenv("LANGSMITH_SOURCE_API_KEY") or shared_key + target_api_key = args.target_api_key or os.getenv("LANGSMITH_TARGET_API_KEY") or shared_key database_url = args.database_url or os.getenv("DATABASE_URL") # Parse metadata filter @@ -436,117 +400,89 @@ def main() -> None: console.print("[red]✗ --metadata-filter must be a JSON object[/red]") sys.exit(1) except json.JSONDecodeError as e: - console.print(f"[red]✗ Invalid JSON in --metadata-filter: {e}[/red]") + console.print(f"[red]✗ Invalid JSON: {e}[/red]") sys.exit(1) - history_limit = args.history_limit - - # Validate keys based on the command + # Validate keys needs_source = bool(args.export_json or args.export_postgres or args.full or args.validate) needs_target = bool(args.import_json or args.full or args.validate) if needs_source and not source_api_key: - console.print("[red]✗ Source API key required (--source-api-key or LANGSMITH_SOURCE_API_KEY or LANGSMITH_API_KEY)[/red]") + console.print("[red]✗ Source API key required[/red]") sys.exit(1) if needs_target and not target_api_key: - console.print("[red]✗ Target API key required (--target-api-key or LANGSMITH_TARGET_API_KEY or 
LANGSMITH_API_KEY)[/red]") + console.print("[red]✗ Target API key required[/red]") sys.exit(1) - # Display header + # Header console.print(Panel.fit( - "[bold magenta]🔄 LangGraph Threads Export Tool[/bold magenta]", + "[bold magenta]LangGraph Threads Migration Tool[/bold magenta]", border_style="magenta", )) - - # Show key info (cross-org vs same-org) if needs_source and needs_target and source_api_key != target_api_key: - console.print("[cyan]ℹ Cross-org migration: using separate API keys for source and target[/cyan]\n") - + console.print("[cyan]ℹ Cross-org: separate API keys[/cyan]") if args.test_single: - console.print("[yellow]⚠ TEST MODE: Single thread only[/yellow]\n") + console.print("[yellow]⚠ TEST MODE: single thread[/yellow]") + console.print() - # Run appropriate command + # Dispatch try: if args.export_json: if not args.source_url: - console.print("[red]✗ --source-url required[/red]") - sys.exit(1) + console.print("[red]✗ --source-url required[/red]"); sys.exit(1) asyncio.run(run_export_json( - source_url=args.source_url, - source_api_key=source_api_key, - output_file=args.export_json, - test_single=args.test_single, - metadata_filter=metadata_filter, - history_limit=history_limit, + source_url=args.source_url, source_api_key=source_api_key, + output_file=args.export_json, test_single=args.test_single, + metadata_filter=metadata_filter, history_limit=args.history_limit, )) - elif args.export_postgres: if not args.source_url: - console.print("[red]✗ --source-url required[/red]") - sys.exit(1) + console.print("[red]✗ --source-url required[/red]"); sys.exit(1) if not database_url: - console.print("[red]✗ --database-url or DATABASE_URL required[/red]") - sys.exit(1) + console.print("[red]✗ --database-url required[/red]"); sys.exit(1) asyncio.run(run_export_postgres( - source_url=args.source_url, - source_api_key=source_api_key, - database_url=database_url, - output_file=args.backup_file, - test_single=args.test_single, - metadata_filter=metadata_filter, - 
history_limit=history_limit, + source_url=args.source_url, source_api_key=source_api_key, + database_url=database_url, output_file=args.backup_file, + test_single=args.test_single, metadata_filter=metadata_filter, + history_limit=args.history_limit, )) - elif args.import_json: if not args.target_url: - console.print("[red]✗ --target-url required[/red]") - sys.exit(1) + console.print("[red]✗ --target-url required[/red]"); sys.exit(1) asyncio.run(run_import_json( - target_url=args.target_url, - target_api_key=target_api_key, - input_file=args.import_json, - dry_run=args.dry_run, + target_url=args.target_url, target_api_key=target_api_key, + input_file=args.import_json, dry_run=args.dry_run, + legacy_terminal_node=args.legacy_terminal_node, )) - elif args.full: if not args.source_url or not args.target_url: - console.print("[red]✗ Both --source-url and --target-url required[/red]") - sys.exit(1) + console.print("[red]✗ Both --source-url and --target-url required[/red]"); sys.exit(1) asyncio.run(run_full_migration( - source_url=args.source_url, - target_url=args.target_url, - source_api_key=source_api_key, - target_api_key=target_api_key, - backup_file=args.backup_file, - dry_run=args.dry_run, - test_single=args.test_single, - metadata_filter=metadata_filter, - history_limit=history_limit, + source_url=args.source_url, target_url=args.target_url, + source_api_key=source_api_key, target_api_key=target_api_key, + backup_file=args.backup_file, dry_run=args.dry_run, + test_single=args.test_single, metadata_filter=metadata_filter, + history_limit=args.history_limit, + legacy_terminal_node=args.legacy_terminal_node, )) - elif args.validate: if not args.source_url or not args.target_url: - console.print("[red]✗ Both --source-url and --target-url required[/red]") - sys.exit(1) + console.print("[red]✗ Both URLs required[/red]"); sys.exit(1) asyncio.run(run_validate( - source_url=args.source_url, - target_url=args.target_url, - source_api_key=source_api_key, - 
target_api_key=target_api_key, + source_url=args.source_url, target_url=args.target_url, + source_api_key=source_api_key, target_api_key=target_api_key, )) - else: parser.print_help() - console.print("\n[yellow]⚠ Specify an action[/yellow]") sys.exit(1) except KeyboardInterrupt: console.print("\n[yellow]⚠ Interrupted[/yellow]") sys.exit(1) except Exception as e: - console.print(f"\n[red]✗ Error: {e}[/red]") + console.print(f"\n[red]✗ {e}[/red]") import traceback - console.print(traceback.format_exc()) + console.print(f"[dim]{traceback.format_exc()}[/dim]") sys.exit(1)