diff --git a/.env.example b/.env.example index a187321..ac56b4b 100644 --- a/.env.example +++ b/.env.example @@ -1,8 +1,13 @@ -# LangSmith API Key (required) +# LangSmith API Key — shared fallback for source and target # Get your key from: https://smith.langchain.com/ → Settings → API Keys # Use a Service Key (lsv2_sk_...) for deployment access LANGSMITH_API_KEY=lsv2_sk_your_api_key_here +# For cross-org migration: separate keys for source and target +# These override LANGSMITH_API_KEY when set +# LANGSMITH_SOURCE_API_KEY=lsv2_sk_source_org_key +# LANGSMITH_TARGET_API_KEY=lsv2_sk_target_org_key + # PostgreSQL Database URL (optional, for --export-postgres) # Example: postgresql://user:password@localhost:5432/dbname DATABASE_URL= diff --git a/README.md b/README.md index c3662fb..d43a1e0 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ -# LangGraph Threads Export Tool +# LangGraph Threads Migration Tool -A Python tool to export threads, checkpoints, and conversation history from LangGraph Cloud deployments. Save your data to JSON files, PostgreSQL databases, or migrate directly to another deployment. +A Python tool to export, backup, and migrate threads with full checkpoint history between LangGraph Cloud deployments. ## Why This Tool? LangGraph Cloud stores your conversation threads and checkpoints, but there's no built-in way to: - **Backup your data** before deleting a deployment -- **Migrate conversations** between environments (prod → dev) +- **Migrate conversations** between environments (prod → staging, or across orgs) - **Store threads in your own database** for analytics or compliance - **Download conversation history** as JSON for processing @@ -14,34 +14,26 @@ This tool solves all of these problems. 
## Features -- **Export to JSON** - Download all threads and checkpoints as a backup file -- **Export to PostgreSQL** - Store threads in your own database with proper schema -- **Migrate between deployments** - Transfer threads from one LangGraph Cloud deployment to another -- **Preserve everything** - Thread IDs, metadata (including `owner` for multi-tenancy), checkpoints, and conversation history -- **Test mode** - Export/migrate a single thread first to verify everything works -- **Dry-run mode** - Preview changes without making any modifications -- **Progress tracking** - Real-time progress bars and detailed summaries - -## Use Cases - -| Scenario | Command | -|----------|---------| -| Backup before deleting deployment | `--export-json backup.json` | -| Cost optimization (expensive → cheaper deployment) | `--full` | -| Store in your own PostgreSQL | `--export-postgres` | -| Environment migration (staging → prod) | `--migrate` | -| Disaster recovery | `--import-json backup.json` | +- **Export to JSON** — Streaming writes, memory-efficient even for multi-GB exports +- **Export to PostgreSQL** — Store threads in your own database with proper schema +- **Migrate between deployments** — Full migration with supersteps (preserves checkpoint chains) and automatic legacy fallback +- **Concurrent fetching** — Configurable parallelism (default: 5 threads) for fast exports +- **Per-page retry** — Automatic retries with exponential backoff on paginated API calls +- **History pagination** — Correct cursor format (`{"configurable": {"checkpoint_id": ...}}`) for complete checkpoint retrieval +- **Legacy import fix** — `--legacy-terminal-node` sets `next=[]` on threads imported via fallback mode +- **Rich progress bars** — Real-time progress with thread counts, elapsed time, and per-thread details +- **Metadata filtering** — Export only threads matching specific metadata (e.g., by workspace) +- **Dry-run & test modes** — Preview changes or test with a single thread before full 
operations +- **Cross-org support** — Separate API keys for source and target deployments ## Installation ```bash -# Clone the repository -git clone https://github.com/YOUR_USERNAME/langgraph-threads-migration.git +git clone https://github.com/farouk09/langgraph-threads-migration.git cd langgraph-threads-migration # Using uv (recommended) -uv venv -source .venv/bin/activate +uv venv && source .venv/bin/activate uv pip install -r requirements.txt # Or using pip @@ -50,8 +42,6 @@ pip install -r requirements.txt ## Configuration -Create a `.env` file: - ```bash cp .env.example .env ``` @@ -59,9 +49,13 @@ cp .env.example .env Edit `.env` with your credentials: ```bash -# Required: LangSmith API key +# Required: LangSmith API key (shared fallback for source and target) LANGSMITH_API_KEY=lsv2_sk_your_api_key_here +# For cross-org migration: separate keys override the shared key +# LANGSMITH_SOURCE_API_KEY=lsv2_sk_source_org_key +# LANGSMITH_TARGET_API_KEY=lsv2_sk_target_org_key + # Optional: PostgreSQL connection URL (for --export-postgres) DATABASE_URL=postgresql://user:password@localhost:5432/dbname ``` @@ -70,9 +64,7 @@ DATABASE_URL=postgresql://user:password@localhost:5432/dbname ## Usage -### Export to JSON file - -Download all threads and checkpoints as a backup: +### Export to JSON ```bash python migrate_threads.py \ @@ -82,42 +74,60 @@ python migrate_threads.py \ ### Export to PostgreSQL -Store threads in your own database: - ```bash python migrate_threads.py \ --source-url https://my-deployment.langgraph.app \ --export-postgres ``` -This creates two tables: -- `langgraph_threads` - Thread metadata and current state -- `langgraph_checkpoints` - Full checkpoint history - -### Migrate between deployments - -Transfer all threads from one deployment to another: +### Full migration (export + import + validate) ```bash +# Same org (shared API key) python migrate_threads.py \ - --source-url https://my-prod.langgraph.app \ - --target-url https://my-dev.langgraph.app \ + 
--source-url https://old-deploy.langgraph.app \ + --target-url https://new-deploy.langgraph.app \ + --full + +# Cross-org (separate API keys) +python migrate_threads.py \ + --source-url https://org1.langgraph.app \ + --target-url https://org2.langgraph.app \ + --source-api-key lsv2_sk_source... \ + --target-api-key lsv2_sk_target... \ --full ``` ### Import from JSON -Restore threads from a backup file: - ```bash python migrate_threads.py \ --target-url https://my-deployment.langgraph.app \ --import-json threads_backup.json ``` -### Test with a single thread first +### Import with legacy terminal node fix + +When importing threads that fall back to legacy mode (no supersteps), the thread's `next` field may point to the graph's entry node instead of being empty. Use `--legacy-terminal-node` to specify your graph's terminal node, which sets `next=[]` so threads are continuable: + +```bash +python migrate_threads.py \ + --target-url https://my-deployment.langgraph.app \ + --import-json backup.json \ + --legacy-terminal-node "MyLastNode.after_handler" +``` -Always recommended before a full operation: +### Filter by metadata + +```bash +# Export threads for a specific workspace +python migrate_threads.py \ + --source-url https://my-deployment.langgraph.app \ + --export-json workspace_4.json \ + --metadata-filter '{"workspace_id": 4}' +``` + +### Test with a single thread first ```bash python migrate_threads.py \ @@ -132,23 +142,48 @@ python migrate_threads.py \ |----------|-------------| | `--source-url` | Source LangGraph Cloud deployment URL | | `--target-url` | Target LangGraph Cloud deployment URL | -| `--api-key` | LangSmith API key (or set in `.env`) | -| `--database-url` | PostgreSQL URL (or set in `.env`) | +| `--api-key` | Shared API key fallback (or `LANGSMITH_API_KEY` env var) | +| `--source-api-key` | Source API key for cross-org (or `LANGSMITH_SOURCE_API_KEY`) | +| `--target-api-key` | Target API key for cross-org (or `LANGSMITH_TARGET_API_KEY`) | +| 
`--database-url` | PostgreSQL URL (or `DATABASE_URL` env var) | | `--export-json FILE` | Export threads to JSON file | | `--export-postgres` | Export threads to PostgreSQL database | | `--import-json FILE` | Import threads from JSON file | -| `--migrate` | Migrate threads (export + import) | | `--full` | Full migration (export + import + validate) | | `--validate` | Compare source vs target thread counts | | `--dry-run` | Simulation mode (no changes made) | | `--test-single` | Process only one thread (for testing) | +| `--metadata-filter JSON` | Filter threads by metadata (JSON object) | +| `--history-limit N` | Max checkpoints per thread (default: all) | +| `--concurrency N` | Parallel thread fetches (default: 5) | +| `--legacy-terminal-node NODE` | Graph terminal node name for legacy imports (sets `next=[]`) | +| `--backup-file FILE` | Backup file path (default: `threads_backup.json`) | + +## Import Strategies + +The tool uses two import strategies, with automatic fallback: + +### 1. Supersteps (preferred) +Replays state changes via `threads.create(supersteps=...)`, preserving the full checkpoint chain. This enables time-travel operations on the target deployment. + +### 2. Legacy fallback +When supersteps fail (e.g., due to incompatible serialized objects in old checkpoints), the tool falls back to `create_thread()` + `update_thread_state()`. This preserves the final state but creates only a single checkpoint. + +**Known issue with legacy import**: Without `--legacy-terminal-node`, threads imported in legacy mode may have `next=['SomeMiddleware.before_handler']` instead of `next=[]`, making them non-continuable. The flag fixes this by telling LangGraph which node "last ran". + +## Key Bug Fixes (vs upstream) + +### History pagination cursor format +The LangGraph Cloud API expects the `before` cursor in the format `{"configurable": {"checkpoint_id": "..."}}`, not `{"checkpoint_id": "..."}`. 
The incorrect format caused silent 500 errors on every page after the first, resulting in incomplete exports. This fix is critical for any thread with more than 100 checkpoints. + +### JSON parsing of agent messages +Agent messages may contain unescaped control characters (e.g., `\n` inside JSON strings). The exporter now uses `strict=False` when loading JSON backups to handle these correctly. ## PostgreSQL Schema When using `--export-postgres`, the tool creates: ```sql --- Threads table CREATE TABLE langgraph_threads ( id SERIAL PRIMARY KEY, thread_id VARCHAR(255) UNIQUE NOT NULL, @@ -160,7 +195,6 @@ CREATE TABLE langgraph_threads ( exported_at TIMESTAMP DEFAULT NOW() ); --- Checkpoints table CREATE TABLE langgraph_checkpoints ( id SERIAL PRIMARY KEY, thread_id VARCHAR(255) REFERENCES langgraph_threads(thread_id), @@ -173,73 +207,46 @@ CREATE TABLE langgraph_checkpoints ( ); ``` -## Example Output - -``` -╭────────────────────────────────────────╮ -│ 🔄 LangGraph Threads Export Tool │ -╰────────────────────────────────────────╯ - -╭─────────────────────────────────────────╮ -│ Phase 1: Export threads from source │ -╰─────────────────────────────────────────╯ -✓ 66 threads found -✓ JSON backup saved: threads_backup.json -✓ Size: 29.30 MB -✓ Total checkpoints exported: 842 -✓ PostgreSQL: 66 threads, 842 checkpoints -``` - -## Important Notes - -### Authentication - -If your LangGraph deployment uses custom authentication (e.g., Auth0), you may need to temporarily disable it during export: - -```json -// langgraph.json - temporarily set auth to null -{ - "auth": null -} -``` - -Remember to re-enable authentication after! - -### Multi-tenancy - -The tool preserves `metadata.owner`, so each user will only see their own threads after migration. - -### Rate Limiting - -Built-in delays (0.2-0.3s) prevent API overload. For large exports (1000+ threads), consider running during off-peak hours. 
- -## Troubleshooting - -| Error | Solution | -|-------|----------| -| `PermissionDeniedError` | Use Service Key (`lsv2_sk_...`), not Personal Token | -| `ConflictError (409)` | Thread already exists (automatically skipped) | -| `asyncpg not installed` | Run `pip install asyncpg` for PostgreSQL support | - ## Project Structure ``` langgraph-threads-migration/ -├── migrate_threads.py # CLI entry point -├── langgraph_export/ # Main package +├── migrate_threads.py # CLI entry point (Rich progress bars) +├── langgraph_export/ │ ├── __init__.py -│ ├── client.py # LangGraph SDK wrapper -│ ├── migrator.py # Thread migration orchestrator +│ ├── client.py # LangGraph SDK wrapper (per-page retry, cursor fix) +│ ├── migrator.py # Migration orchestrator (concurrent fetch, supersteps, legacy) │ ├── models.py # SQLAlchemy models for PostgreSQL -│ └── exporters/ # Export backends +│ └── exporters/ │ ├── base.py # Abstract base exporter -│ ├── json_exporter.py # JSON file export +│ ├── json_exporter.py # Streaming JSON export │ └── postgres_exporter.py # PostgreSQL export ├── requirements.txt ├── .env.example └── LICENSE ``` +## Troubleshooting + +| Error | Solution | +|-------|----------| +| `PermissionDeniedError` | Use Service Key (`lsv2_sk_...`), not Personal Token | +| `ConflictError (409)` | Thread already exists on target (automatically skipped) | +| `asyncpg not installed` | Run `pip install asyncpg` for PostgreSQL support | +| `KeyError: 'configurable'` | Server-side error from wrong pagination cursor — fixed in this fork | +| `next != []` after import | Use `--legacy-terminal-node` with your graph's terminal node | + +## Important Notes + +### Authentication +If your LangGraph deployment uses custom authentication, you may need to temporarily disable it during export. + +### Multi-tenancy +The tool preserves thread metadata including `owner`, so each user will only see their own threads after migration. 
+ +### Rate Limiting +Built-in delays (0.1s) between API calls prevent overload. Failed calls are retried up to 3 times with exponential backoff + jitter. + ## Contributing Contributions are welcome! Please feel free to submit a Pull Request. @@ -247,7 +254,3 @@ Contributions are welcome! Please feel free to submit a Pull Request. ## License MIT License - see [LICENSE](LICENSE) file for details. - -## Acknowledgments - -Built for the [LangGraph](https://github.com/langchain-ai/langgraph) community. diff --git a/langgraph_export/client.py b/langgraph_export/client.py index 8e3721a..6266281 100644 --- a/langgraph_export/client.py +++ b/langgraph_export/client.py @@ -1,68 +1,60 @@ """ -LangGraph Cloud API Client - -Wrapper around the official LangGraph SDK for thread operations. +LangGraph Cloud API Client — with per-page retry and concurrency support. """ +import asyncio +import logging +import random from typing import Any, Dict, List, Optional from langgraph_sdk import get_client +logger = logging.getLogger(__name__) + class LangGraphClient: """Client for interacting with LangGraph Cloud API via the official SDK.""" def __init__(self, base_url: str, api_key: str): - """ - Initialize the LangGraph client. - - Args: - base_url: LangGraph Cloud deployment URL - api_key: LangSmith API key (Service Key recommended) - """ self.base_url = base_url.rstrip("/") self.api_key = api_key self._client = get_client(url=base_url, api_key=api_key) async def close(self) -> None: - """Close the HTTP client.""" - # SDK handles cleanup automatically pass + # ── Retry helper ────────────────────────────────────────────── + + @staticmethod + async def _retry(coro_factory, max_attempts=3, base_delay=0.8, label=""): + """Retry with exponential backoff + jitter. 
Returns result or raises.""" + last_error = None + for attempt in range(max_attempts): + try: + return await coro_factory() + except Exception as e: + last_error = e + if attempt < max_attempts - 1: + delay = base_delay * (2 ** attempt) * (0.7 + random.random() * 0.6) + logger.warning(f"Retry {attempt+1}/{max_attempts} {label}: {e} ({delay:.1f}s)") + await asyncio.sleep(delay) + raise last_error + + # ── Thread operations ───────────────────────────────────────── + async def search_threads( self, limit: int = 100, offset: int = 0, metadata: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: - """ - Search for threads. - - Args: - limit: Maximum number of threads to return - offset: Number of threads to skip - metadata: Optional metadata filter - - Returns: - List of thread dictionaries - """ kwargs = {"limit": limit, "offset": offset} if metadata: kwargs["metadata"] = metadata - threads = await self._client.threads.search(**kwargs) return list(threads) if threads else [] async def get_thread(self, thread_id: str) -> Dict[str, Any]: - """ - Get thread details. - - Args: - thread_id: Thread ID - - Returns: - Thread dictionary with metadata, values, etc. - """ thread = await self._client.threads.get(thread_id) return dict(thread) if thread else {} @@ -70,74 +62,102 @@ async def get_thread_history( self, thread_id: str, limit: int = 100, + before: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: - """ - Get full thread history (checkpoints). 
+ kwargs: Dict[str, Any] = {"limit": limit} + if before: + kwargs["before"] = before + history = await self._client.threads.get_history(thread_id, **kwargs) + return list(history) if history else [] + + async def get_all_history( + self, + thread_id: str, + limit: Optional[int] = None, + page_size: int = 100, + ) -> List[Dict[str, Any]]: + """Get complete history with per-page retry.""" + all_history: List[Dict[str, Any]] = [] + before = None + tid_short = thread_id[:8] - Args: - thread_id: Thread ID - limit: Maximum number of checkpoints to return + while True: + batch_size = page_size + if limit: + remaining = limit - len(all_history) + batch_size = min(page_size, remaining) - Returns: - List of checkpoint dictionaries - """ - history = await self._client.threads.get_history(thread_id, limit=limit) - return list(history) if history else [] + page_num = len(all_history) // page_size + 1 + + batch = await self._retry( + lambda bs=batch_size, b=before: self.get_thread_history( + thread_id, limit=bs, before=b, + ), + label=f"history({tid_short} p{page_num})", + ) + + if not batch: + break + + all_history.extend(batch) + + if limit and len(all_history) >= limit: + break + if len(batch) < batch_size: + break + + # Build cursor for next page — server expects {"configurable": {"checkpoint_id": ...}} + last = batch[-1] + cp_obj = last.get("checkpoint", {}) + cp_id = cp_obj.get("checkpoint_id") or last.get("checkpoint_id") + if not cp_id: + break + before = {"configurable": {"checkpoint_id": cp_id}} + + return all_history + + # ── Write operations ────────────────────────────────────────── async def create_thread( self, thread_id: str, metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - """ - Create a new thread with a specific ID. 
- - Args: - thread_id: Thread ID to use - metadata: Optional metadata to attach - - Returns: - Created thread dictionary - """ thread = await self._client.threads.create( thread_id=thread_id, metadata=metadata or {}, ) return dict(thread) if thread else {} + async def create_thread_with_history( + self, + thread_id: str, + metadata: Optional[Dict[str, Any]] = None, + supersteps: Optional[List[Dict[str, Any]]] = None, + if_exists: Optional[str] = None, + ) -> Dict[str, Any]: + kwargs: Dict[str, Any] = { + "thread_id": thread_id, + "metadata": metadata or {}, + } + if supersteps: + kwargs["supersteps"] = supersteps + if if_exists: + kwargs["if_exists"] = if_exists + thread = await self._client.threads.create(**kwargs) + return dict(thread) if thread else {} + async def update_thread_state( self, thread_id: str, values: Dict[str, Any], as_node: Optional[str] = None, ) -> Dict[str, Any]: - """ - Update thread state. - - Args: - thread_id: Thread ID - values: State values to update - as_node: Optional node name for the update - - Returns: - Update result dictionary - """ result = await self._client.threads.update_state( - thread_id, - values=values, - as_node=as_node, + thread_id, values=values, as_node=as_node, ) return dict(result) if result else {} async def get_thread_state(self, thread_id: str) -> Dict[str, Any]: - """ - Get current thread state. - - Args: - thread_id: Thread ID - - Returns: - Current state dictionary - """ state = await self._client.threads.get_state(thread_id) return dict(state) if state else {} diff --git a/langgraph_export/exporters/json_exporter.py b/langgraph_export/exporters/json_exporter.py index 6c998fe..2887075 100644 --- a/langgraph_export/exporters/json_exporter.py +++ b/langgraph_export/exporters/json_exporter.py @@ -1,79 +1,90 @@ """ JSON Exporter -Export threads to a JSON file. +Export threads to a JSON file with streaming writes. +Threads are written incrementally to avoid holding all data in memory. 
""" import json from datetime import datetime from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, IO, List, Optional from langgraph_export.exporters.base import BaseExporter, ExportStats, ThreadData +_JSON_INDENT = 2 + class JSONExporter(BaseExporter): """ - Export threads to JSON file. + Export threads to JSON file with streaming writes. - Creates a JSON file with all threads, metadata, and checkpoints. + Writes each thread to disk as it's exported, rather than + buffering everything in memory. The output format is + backward-compatible with the original buffered exporter. """ def __init__( self, source_url: str, output_file: str = "threads_backup.json", - indent: int = 2, ): - """ - Initialize JSON exporter. - - Args: - source_url: Source LangGraph deployment URL - output_file: Path to output JSON file - indent: JSON indentation level (default: 2) - """ super().__init__(source_url) self.output_file = Path(output_file) - self.indent = indent - self._threads: List[Dict[str, Any]] = [] + self._file: Optional[IO[str]] = None + self._thread_count = 0 async def connect(self) -> None: - """No connection needed for JSON export.""" - pass + """Open the file and write the JSON header.""" + self.output_file.parent.mkdir(parents=True, exist_ok=True) + self._file = open(self.output_file, "w", encoding="utf-8") + self._thread_count = 0 + # Start the JSON object with threads array + self._file.write("{\n") + self._file.write(f' "threads": [\n') async def export_thread(self, thread: ThreadData) -> None: - """ - Add thread to export buffer. + """Write a single thread to the file immediately.""" + if not self._file: + raise RuntimeError("Exporter not connected. 
Call connect() first.") + if self._thread_count > 0: + self._file.write(",\n") + + thread_json = json.dumps( + thread.to_dict(), indent=_JSON_INDENT, ensure_ascii=False, default=str, + ) + # Indent each line of the thread JSON by 4 spaces (inside the array) + indented = "\n".join(f" {line}" for line in thread_json.splitlines()) + self._file.write(indented) - Args: - thread: Thread data to export - """ - self._threads.append(thread.to_dict()) + self._thread_count += 1 self.stats.threads_exported += 1 self.stats.checkpoints_exported += len(thread.history) async def finalize(self) -> None: - """Save all threads to JSON file.""" - export_data = { - "export_date": datetime.now().isoformat(), - "source": self.source_url, - "total_threads": len(self._threads), - "total_checkpoints": self.stats.checkpoints_exported, - "threads": self._threads, - } + """Close the threads array and write metadata footer.""" + if not self._file: + return - # Ensure parent directory exists - self.output_file.parent.mkdir(parents=True, exist_ok=True) + # Close the threads array + if self._thread_count > 0: + self._file.write("\n") + self._file.write(" ],\n") - # Write JSON file - self.output_file.write_text( - json.dumps(export_data, indent=self.indent, ensure_ascii=False, default=str) - ) + # Write metadata at the end (known only after all threads are exported) + self._file.write(f' "export_date": {json.dumps(datetime.now().isoformat())},\n') + self._file.write(f' "source": {json.dumps(self.source_url)},\n') + self._file.write(f' "total_threads": {self.stats.threads_exported},\n') + self._file.write(f' "total_checkpoints": {self.stats.checkpoints_exported}\n') + self._file.write("}\n") + + self._file.flush() async def close(self) -> None: - """Clear internal buffer.""" - self._threads.clear() + """Close the file handle.""" + if self._file: + self._file.close() + self._file = None def get_file_size_mb(self) -> float: """Get output file size in MB.""" @@ -83,40 +94,23 @@ def 
get_file_size_mb(self) -> float: @staticmethod def load_threads(file_path: str) -> List[ThreadData]: - """ - Load threads from JSON file. - - Args: - file_path: Path to JSON file - - Returns: - List of ThreadData objects - """ + """Load threads from JSON file.""" path = Path(file_path) if not path.exists(): raise FileNotFoundError(f"File not found: {file_path}") - data = json.loads(path.read_text()) + data = json.loads(path.read_text(encoding="utf-8"), strict=False) threads = data.get("threads", []) - return [ThreadData.from_dict(t) for t in threads] @staticmethod def get_export_info(file_path: str) -> Dict[str, Any]: - """ - Get export metadata from JSON file. - - Args: - file_path: Path to JSON file - - Returns: - Dictionary with export date, source, and counts - """ + """Get export metadata from JSON file.""" path = Path(file_path) if not path.exists(): raise FileNotFoundError(f"File not found: {file_path}") - data = json.loads(path.read_text()) + data = json.loads(path.read_text(encoding="utf-8"), strict=False) return { "export_date": data.get("export_date"), "source": data.get("source"), diff --git a/langgraph_export/migrator.py b/langgraph_export/migrator.py index 85f231a..8bf0921 100644 --- a/langgraph_export/migrator.py +++ b/langgraph_export/migrator.py @@ -1,10 +1,12 @@ """ -Thread Migrator - -Main class for orchestrating thread export and migration operations. +Thread Migrator — concurrent fetch, streaming export, per-page retry. 
""" import asyncio +import logging +import random +import time +from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional from langgraph_sdk.errors import ConflictError, APIStatusError @@ -14,58 +16,70 @@ from langgraph_export.exporters.json_exporter import JSONExporter from langgraph_export.exporters.postgres_exporter import PostgresExporter +logger = logging.getLogger(__name__) + +CONCURRENCY = 5 # max parallel thread fetches + + +@dataclass +class MigrationProgress: + """Mutable progress state shared across concurrent tasks.""" + exported: int = 0 + skipped: int = 0 + failed: int = 0 + total_checkpoints: int = 0 + total_threads: int = 0 # set after discovery + start_time: float = field(default_factory=time.time) + errors: List[str] = field(default_factory=list) + + @property + def elapsed(self) -> float: + return time.time() - self.start_time + + @property + def rate(self) -> float: + done = self.exported + self.skipped + self.failed + return done / self.elapsed if self.elapsed > 0 else 0 -class ThreadMigrator: - """ - Thread migration and export manager. - Handles fetching threads from source, exporting to various - destinations, and optionally importing to a target deployment. - """ +class ThreadMigrator: + """Thread migration and export manager with concurrent fetching.""" def __init__( self, source_url: Optional[str] = None, target_url: Optional[str] = None, api_key: str = "", - rate_limit_delay: float = 0.2, + source_api_key: Optional[str] = None, + target_api_key: Optional[str] = None, + rate_limit_delay: float = 0.1, + concurrency: int = CONCURRENCY, ): - """ - Initialize the migrator. 
- - Args: - source_url: Source LangGraph deployment URL - target_url: Target LangGraph deployment URL (for migration) - api_key: LangSmith API key - rate_limit_delay: Delay between API calls (seconds) - """ self.source_url = source_url self.target_url = target_url - self.api_key = api_key + self.source_api_key = source_api_key or api_key + self.target_api_key = target_api_key or api_key self.rate_limit_delay = rate_limit_delay + self.concurrency = concurrency self._source_client: Optional[LangGraphClient] = None self._target_client: Optional[LangGraphClient] = None self._exporters: List[BaseExporter] = [] async def __aenter__(self) -> "ThreadMigrator": - """Async context manager entry.""" await self.connect() return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - """Async context manager exit.""" await self.close() async def connect(self) -> None: - """Initialize clients.""" if self.source_url: - self._source_client = LangGraphClient(self.source_url, self.api_key) + self._source_client = LangGraphClient(self.source_url, self.source_api_key) if self.target_url: - self._target_client = LangGraphClient(self.target_url, self.api_key) + self._target_client = LangGraphClient(self.target_url, self.target_api_key) async def close(self) -> None: - """Close all clients and exporters.""" if self._source_client: await self._source_client.close() if self._target_client: @@ -73,248 +87,434 @@ async def close(self) -> None: for exporter in self._exporters: await exporter.close() - def add_json_exporter(self, output_file: str = "threads_backup.json") -> JSONExporter: - """ - Add JSON file exporter. 
- - Args: - output_file: Path to output JSON file + @staticmethod + async def _retry(coro_factory, max_attempts=3, base_delay=0.8, label=""): + last_error = None + for attempt in range(max_attempts): + try: + return await coro_factory() + except Exception as e: + last_error = e + if attempt < max_attempts - 1: + delay = base_delay * (2 ** attempt) * (0.7 + random.random() * 0.6) + logger.warning(f"Retry {attempt+1}/{max_attempts} {label}: {e} ({delay:.1f}s)") + await asyncio.sleep(delay) + raise last_error - Returns: - The created exporter - """ + def add_json_exporter(self, output_file: str = "threads_backup.json") -> JSONExporter: exporter = JSONExporter(self.source_url or "", output_file) self._exporters.append(exporter) return exporter def add_postgres_exporter(self, database_url: str) -> PostgresExporter: - """ - Add PostgreSQL exporter. - - Args: - database_url: PostgreSQL connection URL - - Returns: - The created exporter - """ exporter = PostgresExporter(self.source_url or "", database_url) self._exporters.append(exporter) return exporter - async def fetch_all_threads( + # ── Discovery: list all thread IDs ──────────────────────────── + + async def _discover_thread_ids( self, limit: Optional[int] = None, - progress_callback: Optional[Callable[[int, str], None]] = None, - ) -> List[ThreadData]: - """ - Fetch all threads from source deployment. 
- - Args: - limit: Maximum number of threads to fetch (None = all) - progress_callback: Optional callback(count, message) for progress updates - - Returns: - List of ThreadData objects - """ + metadata_filter: Optional[Dict[str, Any]] = None, + ) -> List[Dict[str, Any]]: + """Fetch all thread summaries (ID + metadata only, no history).""" if not self._source_client: raise RuntimeError("Source URL not configured") - all_threads: List[ThreadData] = [] + all_summaries: List[Dict[str, Any]] = [] offset = 0 batch_size = 100 if limit is None else min(limit, 100) - # Fetch thread list while True: - if progress_callback: - progress_callback(len(all_threads), f"Fetching threads (offset={offset})...") - - threads = await self._source_client.search_threads( - limit=batch_size, - offset=offset, + batch = await self._retry( + lambda o=offset: self._source_client.search_threads( + limit=batch_size, offset=o, metadata=metadata_filter, + ), + label="search_threads", ) + if not batch: + break + + all_summaries.extend(batch) - if not threads: + if limit and len(all_summaries) >= limit: + all_summaries = all_summaries[:limit] + break + if len(batch) < batch_size: break - for thread_summary in threads: - thread_id = thread_summary.get("thread_id") - if not thread_id: - continue + offset += len(batch) + await asyncio.sleep(self.rate_limit_delay) - try: - # Get full thread details - details = await self._source_client.get_thread(thread_id) - history = await self._source_client.get_thread_history(thread_id) - - thread_data = ThreadData( - thread_id=thread_id, - metadata=details.get("metadata", {}), - values=details.get("values", {}), - history=history, - created_at=details.get("created_at"), - updated_at=details.get("updated_at"), - ) - all_threads.append(thread_data) + return all_summaries - if progress_callback: - progress_callback( - len(all_threads), - f"Fetched thread {thread_id[:8]}... 
({len(history)} checkpoints)" - ) + # ── Fetch single thread (details + history) ─────────────────── - except Exception as e: - if progress_callback: - progress_callback(len(all_threads), f"Error: {thread_id}: {e}") - continue + async def _fetch_thread( + self, + thread_id: str, + history_limit: Optional[int] = None, + ) -> ThreadData: + """Fetch one thread's details + full history.""" + details = await self._retry( + lambda: self._source_client.get_thread(thread_id), + label=f"get({thread_id[:8]})", + ) + history = await self._source_client.get_all_history( + thread_id, limit=history_limit, + ) + return ThreadData( + thread_id=thread_id, + metadata=details.get("metadata", {}), + values=details.get("values", {}), + history=history, + created_at=details.get("created_at"), + updated_at=details.get("updated_at"), + ) - await asyncio.sleep(self.rate_limit_delay) + # ── Fetch all (in-memory, for --full mode) ──────────────────── - # Check limit - if limit and len(all_threads) >= limit: - return all_threads + async def fetch_all_threads( + self, + limit: Optional[int] = None, + metadata_filter: Optional[Dict[str, Any]] = None, + history_limit: Optional[int] = None, + progress_callback: Optional[Callable[[int, str], None]] = None, + ) -> List[ThreadData]: + """Fetch all threads concurrently.""" + summaries = await self._discover_thread_ids(limit, metadata_filter) - offset += len(threads) + if progress_callback: + progress_callback(0, f"Discovered {len(summaries)} threads, fetching...") - # No more pages - if len(threads) < batch_size: - break + sem = asyncio.Semaphore(self.concurrency) + results: List[Optional[ThreadData]] = [None] * len(summaries) + done_count = 0 - await asyncio.sleep(self.rate_limit_delay) + async def fetch_one(idx: int, thread_id: str): + nonlocal done_count + async with sem: + try: + results[idx] = await self._fetch_thread(thread_id, history_limit) + done_count += 1 + if progress_callback: + cp = len(results[idx].history) if results[idx] else 0 + 
progress_callback(done_count, f"Fetched {thread_id[:8]}... ({cp} cp)") + except Exception as e: + done_count += 1 + logger.warning(f"Skip {thread_id[:8]}: {e}") + if progress_callback: + progress_callback(done_count, f"Skip {thread_id[:8]}: {e}") - return all_threads + tasks = [ + fetch_one(i, s.get("thread_id")) + for i, s in enumerate(summaries) + if s.get("thread_id") + ] + await asyncio.gather(*tasks) + + return [t for t in results if t is not None] + + # ── Streaming export (fetch + write one at a time) ──────────── async def export_threads( self, threads: Optional[List[ThreadData]] = None, limit: Optional[int] = None, + metadata_filter: Optional[Dict[str, Any]] = None, + history_limit: Optional[int] = None, progress_callback: Optional[Callable[[int, str], None]] = None, ) -> ExportStats: """ Export threads to all configured exporters. + Streams fetch+write concurrently to minimize memory. + """ + for exporter in self._exporters: + await exporter.connect() - Args: - threads: Pre-fetched threads (if None, fetches from source) - limit: Maximum number of threads to export - progress_callback: Optional callback for progress updates + if threads is not None: + return await self._export_prefetched(threads, progress_callback) - Returns: - Combined export statistics - """ - # Fetch threads if not provided - if threads is None: - threads = await self.fetch_all_threads(limit, progress_callback) + # Streaming: discover → concurrent fetch → sequential write + if not self._source_client: + raise RuntimeError("Source URL not configured") + + summaries = await self._discover_thread_ids(limit, metadata_filter) + total = len(summaries) + + if progress_callback: + progress_callback(0, f"Found {total} threads — exporting...") + + sem = asyncio.Semaphore(self.concurrency) + # Use a queue: fetchers produce, writer consumes + queue: asyncio.Queue[Optional[ThreadData]] = asyncio.Queue(maxsize=self.concurrency * 2) + progress = MigrationProgress(total_threads=total) + + async def 
fetch_worker(thread_id: str): + async with sem: + try: + td = await self._fetch_thread(thread_id, history_limit) + await queue.put(td) + except Exception as e: + progress.failed += 1 + progress.errors.append(f"{thread_id[:8]}: {e}") + logger.warning(f"Fetch failed {thread_id[:8]}: {e}") + + async def writer(): + while True: + td = await queue.get() + if td is None: # sentinel + break + for exporter in self._exporters: + await exporter.export_thread(td) + progress.exported += 1 + progress.total_checkpoints += len(td.history) + if progress_callback: + progress_callback( + progress.exported, + f"{td.thread_id[:8]} ({len(td.history)} cp) " + f"[{progress.exported}/{total}]" + ) + + # Start writer task + writer_task = asyncio.create_task(writer()) + + # Launch all fetchers + thread_ids = [s.get("thread_id") for s in summaries if s.get("thread_id")] + fetch_tasks = [asyncio.create_task(fetch_worker(tid)) for tid in thread_ids] + await asyncio.gather(*fetch_tasks) + + # Signal writer to stop + await queue.put(None) + await writer_task - # Connect all exporters for exporter in self._exporters: - await exporter.connect() + await exporter.finalize() - # Export each thread + return ExportStats( + threads_exported=progress.exported, + checkpoints_exported=progress.total_checkpoints, + ) + + async def _export_prefetched( + self, + threads: List[ThreadData], + progress_callback: Optional[Callable[[int, str], None]] = None, + ) -> ExportStats: total_checkpoints = 0 for i, thread in enumerate(threads): for exporter in self._exporters: await exporter.export_thread(thread) total_checkpoints += len(thread.history) - if progress_callback: progress_callback(i + 1, f"Exported {thread.thread_id[:8]}...") - # Finalize all exporters for exporter in self._exporters: await exporter.finalize() - # Return combined stats - stats = ExportStats( + return ExportStats( threads_exported=len(threads), checkpoints_exported=total_checkpoints, ) - return stats + + # ── Supersteps computation 
──────────────────────────────────── + + @staticmethod + def _compute_values_delta( + prev_values: Dict[str, Any], + curr_values: Dict[str, Any], + ) -> Dict[str, Any]: + delta: Dict[str, Any] = {} + for key, curr_val in curr_values.items(): + prev_val = prev_values.get(key) + if curr_val == prev_val: + continue + + if key == "messages": + prev_ids = {m.get("id") for m in (prev_val or []) if isinstance(m, dict)} + new_msgs = [m for m in (curr_val or []) if isinstance(m, dict) and m.get("id") not in prev_ids] + if new_msgs: + delta["messages"] = new_msgs + elif ( + isinstance(curr_val, list) + and curr_val + and isinstance(curr_val[0], dict) + and "id" in curr_val[0] + ): + prev_ids = {item.get("id") for item in (prev_val or []) if isinstance(item, dict)} + new_items = [item for item in curr_val if isinstance(item, dict) and item.get("id") not in prev_ids] + if new_items: + delta[key] = new_items + else: + delta[key] = curr_val + return delta + + @staticmethod + def _history_to_supersteps(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + if not history: + return [] + + chronological = list(reversed(history)) + supersteps: List[Dict[str, Any]] = [] + + has_writes = any( + isinstance((h.get("metadata") or {}).get("writes"), dict) for h in history + ) + + if has_writes: + for state in chronological: + metadata = state.get("metadata") or {} + writes = metadata.get("writes") + if not writes or not isinstance(writes, dict): + continue + updates = [] + for node_name, node_values in writes.items(): + if node_values is None: + continue + updates.append({"values": node_values, "as_node": node_name}) + if updates: + supersteps.append({"updates": updates}) + else: + prev_values: Dict[str, Any] = {} + for i, state in enumerate(chronological): + curr_values = state.get("values") or {} + if i == 0: + if curr_values: + supersteps.append({ + "updates": [{"values": curr_values, "as_node": "__start__"}], + }) + prev_values = curr_values + continue + + prev_next = 
chronological[i - 1].get("next", []) + as_node = prev_next[0] if prev_next else "__unknown__" + delta = ThreadMigrator._compute_values_delta(prev_values, curr_values) + if delta: + supersteps.append({ + "updates": [{"values": delta, "as_node": as_node}], + }) + prev_values = curr_values + + return supersteps + + # ── Import ──────────────────────────────────────────────────── + + async def _import_thread_with_history(self, thread: ThreadData) -> int: + supersteps = self._history_to_supersteps(thread.history) + await self._target_client.create_thread_with_history( + thread_id=thread.thread_id, + metadata=thread.metadata, + supersteps=supersteps if supersteps else None, + ) + return len(supersteps) + + async def _import_thread_legacy( + self, thread: ThreadData, as_node: Optional[str] = None, + ) -> None: + await self._target_client.create_thread( + thread_id=thread.thread_id, + metadata=thread.metadata, + ) + if thread.values: + await self._target_client.update_thread_state( + thread_id=thread.thread_id, + values=thread.values, + as_node=as_node, + ) async def import_threads( self, threads: List[ThreadData], dry_run: bool = False, + legacy_terminal_node: Optional[str] = None, progress_callback: Optional[Callable[[int, str], None]] = None, ) -> Dict[str, int]: - """ - Import threads to target deployment. 
- - Args: - threads: Threads to import - dry_run: If True, don't actually create threads - progress_callback: Optional callback for progress updates - - Returns: - Dictionary with created, skipped, failed counts - """ if not self._target_client: raise RuntimeError("Target URL not configured") - results = {"created": 0, "skipped": 0, "failed": 0} + results = {"created": 0, "skipped": 0, "failed": 0, "checkpoints": 0} + use_legacy = False for i, thread in enumerate(threads): try: if dry_run: + supersteps = self._history_to_supersteps(thread.history) results["skipped"] += 1 if progress_callback: - progress_callback(i + 1, f"[DRY-RUN] Would create {thread.thread_id[:8]}...") + progress_callback( + i + 1, + f"[DRY-RUN] {thread.thread_id[:8]}... ({len(supersteps)} supersteps)" + ) continue - # Create thread - await self._target_client.create_thread( - thread_id=thread.thread_id, - metadata=thread.metadata, - ) - - # Update state if available - if thread.values: - await self._target_client.update_thread_state( - thread_id=thread.thread_id, - values=thread.values, - ) + if not use_legacy: + try: + checkpoints = await self._import_thread_with_history(thread) + results["checkpoints"] += checkpoints + except APIStatusError as e: + if e.status_code in (400, 422): + use_legacy = True + if progress_callback: + progress_callback(i + 1, "supersteps unsupported → legacy mode") + await self._import_thread_legacy(thread, as_node=legacy_terminal_node) + else: + raise + else: + await self._import_thread_legacy(thread, as_node=legacy_terminal_node) results["created"] += 1 if progress_callback: - progress_callback(i + 1, f"Created {thread.thread_id[:8]}...") + cp_count = len(thread.history) + mode = "legacy" if use_legacy else f"{cp_count} cp" + progress_callback(i + 1, f"Created {thread.thread_id[:8]}... 
({mode})") except ConflictError: results["skipped"] += 1 if progress_callback: - progress_callback(i + 1, f"Skipped (exists) {thread.thread_id[:8]}...") - + progress_callback(i + 1, f"Skip (exists) {thread.thread_id[:8]}...") except APIStatusError as e: results["failed"] += 1 if progress_callback: - progress_callback(i + 1, f"Failed {thread.thread_id[:8]}: {e}") - + progress_callback(i + 1, f"FAIL {thread.thread_id[:8]}: {e}") except Exception as e: results["failed"] += 1 if progress_callback: - progress_callback(i + 1, f"Error {thread.thread_id[:8]}: {e}") + progress_callback(i + 1, f"ERROR {thread.thread_id[:8]}: {e}") await asyncio.sleep(self.rate_limit_delay) return results - async def validate_migration(self) -> Dict[str, int]: - """ - Compare thread counts between source and target. + # ── Validation ──────────────────────────────────────────────── - Returns: - Dictionary with source_count, target_count - """ + async def validate_migration( + self, + check_history: bool = False, + sample_thread_id: Optional[str] = None, + ) -> Dict[str, Any]: source_count = 0 target_count = 0 if self._source_client: threads = await self._source_client.search_threads(limit=10000) source_count = len(threads) - if self._target_client: threads = await self._target_client.search_threads(limit=10000) target_count = len(threads) - return { + result: Dict[str, Any] = { "source_count": source_count, "target_count": target_count, "difference": source_count - target_count, } + + if check_history and sample_thread_id: + if self._source_client: + src_history = await self._source_client.get_all_history(sample_thread_id) + result["history_source"] = len(src_history) + if self._target_client: + tgt_history = await self._target_client.get_all_history(sample_thread_id) + result["history_target"] = len(tgt_history) + + return result diff --git a/migrate_threads.py b/migrate_threads.py index 45fd63e..23d9cca 100644 --- a/migrate_threads.py +++ b/migrate_threads.py @@ -1,326 +1,357 @@ 
#!/usr/bin/env python3
"""
LangGraph Threads Migration Tool

Export/import threads with full checkpoint history between LangGraph Cloud deployments.
Features: concurrent fetching, streaming JSON export, per-page retry, rich progress.
"""

import argparse
import asyncio
import json
import os
import sys
from typing import Optional

from dotenv import load_dotenv
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
    Progress, SpinnerColumn, TextColumn, BarColumn,
    TimeElapsedColumn, MofNCompleteColumn,
)
from rich.table import Table

from langgraph_export import ThreadMigrator
from langgraph_export.exporters import JSONExporter

load_dotenv()
console = Console()


def make_progress(**kwargs) -> Progress:
    """Create a rich Progress bar with consistent columns.

    The trailing column renders the per-task ``detail`` field, so every task
    created on this Progress must pass ``detail=...``.
    """
    return Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(bar_width=30),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
        TextColumn("•"),
        TextColumn("{task.fields[detail]}"),
        console=console,
        **kwargs,
    )


# ── Export to JSON ──────────────────────────────────────────────

async def run_export_json(
    source_url: str,
    source_api_key: str,
    output_file: str,
    test_single: bool = False,
    metadata_filter: Optional[dict] = None,
    history_limit: Optional[int] = None,
) -> None:
    """Export threads (with full checkpoint history) to a JSON file.

    Args:
        source_url: Source LangGraph deployment URL.
        source_api_key: API key for the source deployment.
        output_file: Path of the JSON backup to write.
        test_single: When True, export only the first thread.
        metadata_filter: Optional server-side metadata filter.
        history_limit: Max checkpoints per thread (None = all).
    """
    console.print(Panel.fit(
        "[bold cyan]Export threads to JSON[/bold cyan]",
        border_style="cyan",
    ))
    if metadata_filter:
        console.print(f"  [dim]Filter:[/dim] {metadata_filter}")
    console.print(f"  [dim]Output:[/dim] {output_file}")
    console.print()

    async with ThreadMigrator(source_url=source_url, source_api_key=source_api_key) as migrator:
        # Report the migrator's actual setting instead of a hard-coded "5".
        console.print(f"  [dim]Concurrency:[/dim] {migrator.concurrency} parallel fetches")
        json_exporter = migrator.add_json_exporter(output_file)

        with make_progress() as progress:
            task = progress.add_task(
                "[green]Fetching & writing", total=None, detail="discovering..."
            )

            def on_progress(count: int, message: str) -> None:
                progress.update(task, completed=count, detail=message)

            limit = 1 if test_single else None
            # Single pass: export_threads discovers thread IDs itself. The
            # previous version called the private _discover_thread_ids first
            # and then let export_threads rediscover, doubling search traffic.
            stats = await migrator.export_threads(
                limit=limit,
                metadata_filter=metadata_filter,
                history_limit=history_limit,
                progress_callback=on_progress,
            )

        # Summary
        console.print()
        table = Table(title="Export Complete", show_header=False, border_style="green")
        table.add_column("Metric", style="cyan")
        table.add_column("Value", justify="right", style="bold")
        table.add_row("Threads", str(stats.threads_exported))
        table.add_row("Checkpoints", str(stats.checkpoints_exported))
        table.add_row("File", output_file)
        table.add_row("Size", f"{json_exporter.get_file_size_mb():.2f} MB")
        console.print(table)


# ── Export to PostgreSQL ────────────────────────────────────────

async def run_export_postgres(
    source_url: str,
    source_api_key: str,
    database_url: str,
    output_file: str,
    test_single: bool = False,
    metadata_filter: Optional[dict] = None,
    history_limit: Optional[int] = None,
) -> None:
    """Export threads to PostgreSQL, writing a JSON backup alongside."""
    console.print(Panel.fit(
        "[bold cyan]Export threads to PostgreSQL + JSON[/bold cyan]",
        border_style="cyan",
    ))

    async with ThreadMigrator(source_url=source_url, source_api_key=source_api_key) as migrator:
        json_exporter = migrator.add_json_exporter(output_file)
        pg_exporter = migrator.add_postgres_exporter(database_url)

        with make_progress() as progress:
            task = progress.add_task("[cyan]Exporting...", total=None, detail="starting...")

            def on_progress(count: int, message: str) -> None:
                progress.update(task, completed=count, detail=message)

            limit = 1 if test_single else None
            stats = await migrator.export_threads(
                limit=limit,
                metadata_filter=metadata_filter,
                history_limit=history_limit,
                progress_callback=on_progress,
            )

        db_stats = await pg_exporter.get_database_stats()
        console.print(f"\n[green]✓[/green] Threads: {stats.threads_exported}")
        console.print(f"[green]✓[/green] JSON: {output_file} ({json_exporter.get_file_size_mb():.2f} MB)")
        console.print(f"[green]✓[/green] PostgreSQL: {db_stats['threads']} threads, {db_stats['checkpoints']} checkpoints")


# ── Import from JSON ────────────────────────────────────────────

async def run_import_json(
    target_url: str,
    target_api_key: str,
    input_file: str,
    dry_run: bool = False,
    legacy_terminal_node: Optional[str] = None,
) -> None:
    """Import threads from a JSON backup into the target deployment."""
    console.print(Panel.fit(
        "[bold green]Import threads from JSON[/bold green]",
        border_style="green",
    ))

    threads = JSONExporter.load_threads(input_file)
    info = JSONExporter.get_export_info(input_file)
    total = len(threads)

    console.print(f"  [dim]Source:[/dim] {info['source']}")
    console.print(f"  [dim]Date:[/dim] {info['export_date']}")
    console.print(f"  [dim]Threads:[/dim] {total}")
    if dry_run:
        console.print("  [yellow]DRY-RUN — no changes[/yellow]")
    console.print()

    async with ThreadMigrator(target_url=target_url, target_api_key=target_api_key) as migrator:
        with make_progress() as progress:
            task = progress.add_task("[green]Importing...", total=total, detail="starting...")

            def on_progress(count: int, message: str) -> None:
                progress.update(task, completed=count, detail=message)

            results = await migrator.import_threads(
                threads, dry_run=dry_run,
                legacy_terminal_node=legacy_terminal_node,
                progress_callback=on_progress,
            )

        console.print()
        table = Table(title="Import Summary", border_style="green")
        table.add_column("Status", style="cyan")
        table.add_column("Count", justify="right", style="bold")
        table.add_row("Created", str(results["created"]))
        table.add_row("Skipped (exists)", str(results["skipped"]))
        table.add_row("Failed", str(results["failed"]))
        table.add_row("Checkpoints (supersteps)", str(results.get("checkpoints", 0)))
        console.print(table)


# ── Full migration ──────────────────────────────────────────────

async def run_full_migration(
    source_url: str,
    target_url: str,
    source_api_key: str,
    target_api_key: str,
    backup_file: str,
    dry_run: bool = False,
    test_single: bool = False,
    metadata_filter: Optional[dict] = None,
    history_limit: Optional[int] = None,
    legacy_terminal_node: Optional[str] = None,
) -> None:
    """Full migration: export from source, import to target, then validate."""
    # Phase 1: Export
    console.print(Panel.fit(
        "[bold cyan]Phase 1 — Export from source[/bold cyan]",
        border_style="cyan",
    ))
    if metadata_filter:
        console.print(f"  [dim]Filter:[/dim] {metadata_filter}")

    async with ThreadMigrator(
        source_url=source_url,
        target_url=target_url,
        source_api_key=source_api_key,
        target_api_key=target_api_key,
    ) as migrator:
        json_exporter = migrator.add_json_exporter(backup_file)

        with make_progress() as progress:
            task = progress.add_task("[cyan]Exporting...", total=None, detail="discovering...")
            limit = 1 if test_single else None
            threads = await migrator.fetch_all_threads(
                limit=limit,
                metadata_filter=metadata_filter,
                history_limit=history_limit,
                progress_callback=lambda c, m: progress.update(task, completed=c, detail=m),
            )

        # Write the JSON backup before touching the target.
        await json_exporter.connect()
        for thread in threads:
            await json_exporter.export_thread(thread)
        await json_exporter.finalize()

        console.print(f"  [green]✓[/green] {len(threads)} threads → {backup_file} ({json_exporter.get_file_size_mb():.2f} MB)")

        # Phase 2: Import
        console.print(Panel.fit(
            f"[bold green]Phase 2 — Import {len(threads)} threads to target[/bold green]",
            border_style="green",
        ))
        if dry_run:
            console.print("  [yellow]DRY-RUN — no changes[/yellow]")

        with make_progress() as progress:
            task = progress.add_task("[green]Importing...", total=len(threads), detail="starting...")
            results = await migrator.import_threads(
                threads,
                dry_run=dry_run,
                legacy_terminal_node=legacy_terminal_node,
                progress_callback=lambda c, m: progress.update(task, completed=c, detail=m),
            )

        table = Table(title="Import Summary", border_style="green")
        table.add_column("Status", style="cyan")
        table.add_column("Count", justify="right", style="bold")
        table.add_row("Created", str(results["created"]))
        table.add_row("Skipped", str(results["skipped"]))
        table.add_row("Failed", str(results["failed"]))
        table.add_row("Checkpoints", str(results.get("checkpoints", 0)))
        console.print(table)

        # Phase 3: Validate
        console.print(Panel.fit(
            "[bold blue]Phase 3 — Validate[/bold blue]",
            border_style="blue",
        ))
        # Only spot-check history in test mode, where one thread was migrated.
        sample_id = threads[0].thread_id if test_single and threads else None
        validation = await migrator.validate_migration(
            check_history=bool(sample_id),
            sample_thread_id=sample_id,
        )

        table = Table(title="Validation", border_style="blue")
        table.add_column("Metric", style="cyan")
        table.add_column("Source", justify="right")
        table.add_column("Target", justify="right")
        table.add_row("Threads", str(validation["source_count"]), str(validation["target_count"]))
        if "history_source" in validation:
            table.add_row(
                f"History ({sample_id[:8]}...)",
                str(validation["history_source"]),
                str(validation["history_target"]),
            )
        console.print(table)

        if validation["difference"] <= 0:
            console.print("[green]✓ Thread count validated[/green]")
        else:
            console.print(f"[yellow]⚠ {validation['difference']} threads missing on target[/yellow]")
───────────────────────────────────────────────── async def run_validate( source_url: str, target_url: str, - api_key: str, + source_api_key: str, + target_api_key: str, ) -> None: - """Validate migration by comparing thread counts.""" - console.print(Panel.fit( - "[bold blue]Validating migration[/bold blue]", - border_style="blue", - )) + console.print(Panel.fit("[bold blue]Validate migration[/bold blue]", border_style="blue")) async with ThreadMigrator( source_url=source_url, target_url=target_url, - api_key=api_key, + source_api_key=source_api_key, + target_api_key=target_api_key, ) as migrator: validation = await migrator.validate_migration() - table = Table(title="Source vs Target") - table.add_column("Deployment", style="cyan") - table.add_column("Threads", justify="right", style="magenta") - table.add_column("Status", style="green") + table = Table(title="Source vs Target", border_style="blue") + table.add_column("Deployment", style="cyan") + table.add_column("Threads", justify="right", style="bold") + table.add_row("Source", str(validation["source_count"])) + table.add_row("Target", str(validation["target_count"])) + console.print(table) - table.add_row("Source", str(validation["source_count"]), "✓") - status = "✓" if validation["difference"] <= 0 else "⚠" - table.add_row("Target", str(validation["target_count"]), status) - - console.print(table) + if validation["difference"] <= 0: + console.print("[green]✓ Migration validated[/green]") + else: + console.print(f"[yellow]⚠ {validation['difference']} threads missing on target[/yellow]") - if validation["difference"] <= 0: - console.print("[green]✓ Migration validated successfully![/green]") - else: - console.print(f"[yellow]⚠ Target has {validation['difference']} fewer threads[/yellow]") +# ── CLI ─────────────────────────────────────────────────────────── def main() -> None: - """Main entry point.""" parser = argparse.ArgumentParser( - description="LangGraph Threads Export Tool", + description="LangGraph 
Threads Migration Tool", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Export to JSON python migrate_threads.py --source-url https://my.langgraph.app --export-json backup.json - # Export to PostgreSQL - python migrate_threads.py --source-url https://my.langgraph.app --export-postgres + # Full migration (same org) + python migrate_threads.py --source-url https://src.langgraph.app --target-url https://tgt.langgraph.app --full - # Full migration - python migrate_threads.py --source-url https://prod.langgraph.app --target-url https://dev.langgraph.app --full + # Full migration (cross-org) + python migrate_threads.py --source-url https://org1.langgraph.app --target-url https://org2.langgraph.app \\ + --source-api-key lsv2_sk_... --target-api-key lsv2_sk_... --full # Import from JSON python migrate_threads.py --target-url https://dev.langgraph.app --import-json backup.json @@ -329,8 +360,10 @@ def main() -> None: parser.add_argument("--source-url", help="Source LangGraph deployment URL") parser.add_argument("--target-url", help="Target LangGraph deployment URL") - parser.add_argument("--api-key", help="LangSmith API key (or set LANGSMITH_API_KEY)") - parser.add_argument("--database-url", help="PostgreSQL URL (or set DATABASE_URL)") + parser.add_argument("--api-key", help="Shared API key (or LANGSMITH_API_KEY)") + parser.add_argument("--source-api-key", help="Source API key (or LANGSMITH_SOURCE_API_KEY)") + parser.add_argument("--target-api-key", help="Target API key (or LANGSMITH_TARGET_API_KEY)") + parser.add_argument("--database-url", help="PostgreSQL URL (or DATABASE_URL)") parser.add_argument("--export-json", metavar="FILE", help="Export to JSON file") parser.add_argument("--export-postgres", action="store_true", help="Export to PostgreSQL") @@ -341,82 +374,115 @@ def main() -> None: parser.add_argument("--backup-file", default="threads_backup.json", help="Backup file path") parser.add_argument("--dry-run", action="store_true", 
help="Simulation mode") parser.add_argument("--test-single", action="store_true", help="Test with single thread") + parser.add_argument("--metadata-filter", metavar="JSON", + help="Filter threads by metadata (JSON, e.g. '{\"workspace_id\": 4}')") + parser.add_argument("--history-limit", type=int, default=None, + help="Max checkpoints per thread (default: all)") + parser.add_argument("--concurrency", type=int, default=5, + help="Parallel thread fetches (default: 5)") + parser.add_argument("--legacy-terminal-node", metavar="NODE", + help="Graph terminal node name for legacy imports (sets next=[] via as_node)") args = parser.parse_args() - # Load from environment if not provided - api_key = args.api_key or os.getenv("LANGSMITH_API_KEY") + # Resolve API keys + shared_key = args.api_key or os.getenv("LANGSMITH_API_KEY") + source_api_key = args.source_api_key or os.getenv("LANGSMITH_SOURCE_API_KEY") or shared_key + target_api_key = args.target_api_key or os.getenv("LANGSMITH_TARGET_API_KEY") or shared_key database_url = args.database_url or os.getenv("DATABASE_URL") - if not api_key: - console.print("[red]✗ LANGSMITH_API_KEY required[/red]") + # Parse metadata filter + metadata_filter = None + if args.metadata_filter: + try: + metadata_filter = json.loads(args.metadata_filter) + if not isinstance(metadata_filter, dict): + console.print("[red]✗ --metadata-filter must be a JSON object[/red]") + sys.exit(1) + except json.JSONDecodeError as e: + console.print(f"[red]✗ Invalid JSON: {e}[/red]") + sys.exit(1) + + # Validate keys + needs_source = bool(args.export_json or args.export_postgres or args.full or args.validate) + needs_target = bool(args.import_json or args.full or args.validate) + + if needs_source and not source_api_key: + console.print("[red]✗ Source API key required[/red]") + sys.exit(1) + if needs_target and not target_api_key: + console.print("[red]✗ Target API key required[/red]") sys.exit(1) - # Display header + # Header console.print(Panel.fit( - "[bold 
magenta]🔄 LangGraph Threads Export Tool[/bold magenta]", + "[bold magenta]LangGraph Threads Migration Tool[/bold magenta]", border_style="magenta", )) - + if needs_source and needs_target and source_api_key != target_api_key: + console.print("[cyan]ℹ Cross-org: separate API keys[/cyan]") if args.test_single: - console.print("[yellow]⚠ TEST MODE: Single thread only[/yellow]\n") + console.print("[yellow]⚠ TEST MODE: single thread[/yellow]") + console.print() - # Run appropriate command + # Dispatch try: if args.export_json: if not args.source_url: - console.print("[red]✗ --source-url required[/red]") - sys.exit(1) + console.print("[red]✗ --source-url required[/red]"); sys.exit(1) asyncio.run(run_export_json( - args.source_url, api_key, args.export_json, args.test_single + source_url=args.source_url, source_api_key=source_api_key, + output_file=args.export_json, test_single=args.test_single, + metadata_filter=metadata_filter, history_limit=args.history_limit, )) - elif args.export_postgres: if not args.source_url: - console.print("[red]✗ --source-url required[/red]") - sys.exit(1) + console.print("[red]✗ --source-url required[/red]"); sys.exit(1) if not database_url: - console.print("[red]✗ --database-url or DATABASE_URL required[/red]") - sys.exit(1) + console.print("[red]✗ --database-url required[/red]"); sys.exit(1) asyncio.run(run_export_postgres( - args.source_url, api_key, database_url, args.backup_file, args.test_single + source_url=args.source_url, source_api_key=source_api_key, + database_url=database_url, output_file=args.backup_file, + test_single=args.test_single, metadata_filter=metadata_filter, + history_limit=args.history_limit, )) - elif args.import_json: if not args.target_url: - console.print("[red]✗ --target-url required[/red]") - sys.exit(1) + console.print("[red]✗ --target-url required[/red]"); sys.exit(1) asyncio.run(run_import_json( - args.target_url, api_key, args.import_json, args.dry_run + target_url=args.target_url, 
target_api_key=target_api_key, + input_file=args.import_json, dry_run=args.dry_run, + legacy_terminal_node=args.legacy_terminal_node, )) - elif args.full: if not args.source_url or not args.target_url: - console.print("[red]✗ Both --source-url and --target-url required[/red]") - sys.exit(1) + console.print("[red]✗ Both --source-url and --target-url required[/red]"); sys.exit(1) asyncio.run(run_full_migration( - args.source_url, args.target_url, api_key, - args.backup_file, args.dry_run, args.test_single + source_url=args.source_url, target_url=args.target_url, + source_api_key=source_api_key, target_api_key=target_api_key, + backup_file=args.backup_file, dry_run=args.dry_run, + test_single=args.test_single, metadata_filter=metadata_filter, + history_limit=args.history_limit, + legacy_terminal_node=args.legacy_terminal_node, )) - elif args.validate: if not args.source_url or not args.target_url: - console.print("[red]✗ Both --source-url and --target-url required[/red]") - sys.exit(1) - asyncio.run(run_validate(args.source_url, args.target_url, api_key)) - + console.print("[red]✗ Both URLs required[/red]"); sys.exit(1) + asyncio.run(run_validate( + source_url=args.source_url, target_url=args.target_url, + source_api_key=source_api_key, target_api_key=target_api_key, + )) else: parser.print_help() - console.print("\n[yellow]⚠ Specify an action[/yellow]") sys.exit(1) except KeyboardInterrupt: console.print("\n[yellow]⚠ Interrupted[/yellow]") sys.exit(1) except Exception as e: - console.print(f"\n[red]✗ Error: {e}[/red]") + console.print(f"\n[red]✗ {e}[/red]") import traceback - console.print(traceback.format_exc()) + console.print(f"[dim]{traceback.format_exc()}[/dim]") sys.exit(1) diff --git a/requirements.txt b/requirements.txt index 81317e6..f51021f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -langgraph-sdk>=0.1.66 +langgraph-sdk>=0.3.4 rich>=13.9.0 python-dotenv>=1.0.0 sqlalchemy[asyncio]>=2.0.0 # Optional: for PostgreSQL export