Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# PageIndex runtime configuration — overrides the schema defaults in pageindex/config.py.
model: "gpt-4o-2024-11-20"
# Number of leading pages scanned when searching for a table of contents.
toc_check_page_num: 20
# Maximum pages allowed in a single leaf node of the document tree.
max_page_num_each_node: 10
# Token budget per node (approximate; see max_token_num_each_node in config.py).
max_token_num_each_node: 20000
# Enrichment toggles: attach IDs / summaries / doc description / raw text to nodes.
if_add_node_id: true
if_add_node_summary: true
if_add_doc_description: false
if_add_node_text: false
91 changes: 91 additions & 0 deletions pageindex/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
import yaml
from pathlib import Path
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel, Field, ValidationError

class PageIndexConfig(BaseModel):
    """
    Configuration schema for PageIndex.

    Fields can be overridden via a YAML file or user-supplied options (see
    ConfigLoader.load). Pydantic validates value types, and unknown keys are
    rejected because of ``extra = "forbid"`` below, so config typos fail fast.
    """
    # LLM model used for all API calls.
    model: str = Field(default="gpt-4o", description="LLM model to use")

    # PDF Processing
    toc_check_page_num: int = Field(default=3, description="Number of pages to check for TOC")
    max_page_num_each_node: int = Field(default=5, description="Maximum pages per leaf node")
    max_token_num_each_node: int = Field(default=4000, description="Max tokens per node")  # approximate token budget

    # Enrichment
    if_add_node_id: bool = Field(default=True, description="Add unique ID to nodes")
    if_add_node_summary: bool = Field(default=True, description="Generate summary for nodes")
    if_add_doc_description: bool = Field(default=True, description="Generate doc-level description")
    if_add_node_text: bool = Field(default=True, description="Keep raw text in nodes")

    # Tree Optimization
    if_thinning: bool = Field(default=True, description="Merge small adjacent nodes")
    thinning_threshold: int = Field(default=500, description="Token threshold for merging")
    summary_token_threshold: int = Field(default=200, description="Min tokens required to trigger summary generation")

    # Additional
    api_key: Optional[str] = Field(default=None, description="OpenAI API Key (optional, prefers env var)")

    class Config:
        # Pydantic v1-style inner config. NOTE(review): deprecated under
        # Pydantic v2 (use model_config = ConfigDict(...)) — confirm which
        # pydantic major version the project pins before migrating.
        arbitrary_types_allowed = True
        extra = "forbid"  # reject unknown keys instead of silently ignoring them


class ConfigLoader:
    """
    Loads PageIndex configuration, merging YAML defaults with user overrides.

    Default config path resolution order:
      1. explicit ``default_path`` argument,
      2. the ``PAGEINDEX_CONFIG`` environment variable,
      3. ``config.yaml`` in the current working directory,
      4. ``config.yaml`` at the repository root (parent of this package).
    """

    def __init__(self, default_path: Optional[Union[str, Path]] = None):
        """
        Args:
            default_path: Path to the default YAML config file. If None, it is
                resolved from the environment / filesystem (see class docstring).
        """
        if default_path is None:
            env_path = os.getenv("PAGEINDEX_CONFIG")
            if env_path:
                default_path = Path(env_path)
            else:
                cwd_path = Path.cwd() / "config.yaml"
                repo_path = Path(__file__).resolve().parents[1] / "config.yaml"
                # Prefer a config.yaml in the working directory over the bundled one.
                default_path = cwd_path if cwd_path.exists() else repo_path

        self.default_path = default_path
        self._default_dict = self._load_yaml(default_path) if default_path else {}

    @staticmethod
    def _load_yaml(path: Optional[Union[str, Path]]) -> Dict[str, Any]:
        """
        Read a YAML mapping from *path*.

        Returns {} when the path is missing, unreadable, malformed, or does
        not contain a top-level mapping.
        """
        if not path:
            return {}
        # Fix: the public __init__ signature accepts str paths, but a str has
        # no .exists() — coerce to Path before using pathlib methods.
        path = Path(path)
        if not path.exists():
            return {}
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except Exception as e:
            print(f"Warning: Failed to load config from {path}: {e}")
            return {}
        # A YAML file may legally hold a scalar or list; only a mapping can be
        # merged into the config dict, so anything else falls back to {}.
        return data if isinstance(data, dict) else {}

    def load(self, user_opt: Optional[Union[Dict[str, Any], Any]] = None) -> PageIndexConfig:
        """
        Load configuration, merging defaults with user overrides and validating via Pydantic.

        Args:
            user_opt: Dictionary or object (e.g. SimpleNamespace) with overrides.
                Entries whose value is None are ignored.

        Returns:
            PageIndexConfig: Validated configuration object.

        Raises:
            TypeError: If user_opt is neither a dict nor an object with __dict__.
            ValueError: If the merged configuration fails Pydantic validation.
        """
        user_dict: Dict[str, Any] = {}
        if user_opt is None:
            pass
        elif isinstance(user_opt, dict):
            # Check dict first: instances of dict subclasses can also expose
            # __dict__, and the attribute branch would mis-handle them.
            user_dict = {k: v for k, v in user_opt.items() if v is not None}
        elif hasattr(user_opt, '__dict__'):
            # Handle SimpleNamespace or other attribute-bearing objects.
            user_dict = {k: v for k, v in vars(user_opt).items() if v is not None}
        else:
            raise TypeError(f"user_opt must be dict or object, got {type(user_opt)}")

        # User overrides take precedence over file defaults.
        merged_data = {**self._default_dict, **user_dict}

        try:
            return PageIndexConfig(**merged_data)
        except ValidationError as e:
            # Chain the cause so the full Pydantic error detail is preserved.
            raise ValueError(f"Configuration validation failed: {e}") from e
Empty file added pageindex/core/__init__.py
Empty file.
245 changes: 245 additions & 0 deletions pageindex/core/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
import tiktoken
import openai
import logging
import os
import time
import json
import asyncio
from typing import Optional, List, Dict, Any, Union, Tuple
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment
# (no-op if no .env file is present).
load_dotenv()

# Resolve the API key once at import time; CHATGPT_API_KEY is accepted as a
# fallback variable name. May be None — callers check before use.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("CHATGPT_API_KEY")

def count_tokens(text: Optional[str], model: str = "gpt-4o") -> int:
    """
    Return the number of tokens in *text* under the tokenizer for *model*.

    Args:
        text (Optional[str]): Text to tokenize; None or empty counts as 0.
        model (str): Model whose encoding is used. Defaults to "gpt-4o".

    Returns:
        int: The token count.
    """
    if not text:
        return 0
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown or newer model names fall back to the cl100k_base encoding.
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

def ChatGPT_API_with_finish_reason(
    model: str,
    prompt: str,
    api_key: Optional[str] = OPENAI_API_KEY,
    chat_history: Optional[List[Dict[str, str]]] = None
) -> Tuple[str, str]:
    """
    Call OpenAI Chat Completion API and return content along with finish reason.

    Retries up to 10 times with a 1-second pause on any API error.

    Args:
        model (str): The model name (e.g., "gpt-4o").
        prompt (str): The user prompt.
        api_key (Optional[str]): OpenAI API key. Defaults to env var.
        chat_history (Optional[List[Dict[str, str]]]): Previous messages for context.

    Returns:
        Tuple[str, str]: (content, finish_reason) where finish_reason is one of
        "finished", "max_output_reached", "missing_api_key", or "error".
    """
    max_retries = 10
    if not api_key:
        logging.error("No API key provided.")
        return "Error", "missing_api_key"

    # Build the message list once: it is identical on every retry attempt.
    # chat_history is shallow-copied so the caller's list is not mutated.
    if chat_history:
        messages = chat_history.copy()
        messages.append({"role": "user", "content": prompt})
    else:
        messages = [{"role": "user", "content": prompt}]

    client = openai.OpenAI(api_key=api_key)
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
            )

            content = response.choices[0].message.content or ""
            finish_reason = response.choices[0].finish_reason

            # "length" means the model hit its output-token limit mid-answer.
            if finish_reason == "length":
                return content, "max_output_reached"
            return content, "finished"

        except Exception as e:
            logging.error(f"Error: {e}")
            if attempt < max_retries - 1:
                logging.warning("Retrying ChatGPT API call (attempt %d/%d)", attempt + 2, max_retries)
                time.sleep(1)
            else:
                logging.error('Max retries reached for prompt: ' + prompt[:50] + '...')
                return "Error", "error"
    return "Error", "max_retries"

def ChatGPT_API(
    model: str,
    prompt: str,
    api_key: Optional[str] = OPENAI_API_KEY,
    chat_history: Optional[List[Dict[str, str]]] = None
) -> str:
    """
    Call OpenAI Chat Completion API and return the content string.

    Thin wrapper around ChatGPT_API_with_finish_reason that discards the
    finish reason, so the retry/error-handling logic lives in one place
    instead of being duplicated.

    Args:
        model (str): The model name.
        prompt (str): The user prompt.
        api_key (Optional[str]): OpenAI API key.
        chat_history (Optional[List[Dict[str, str]]]): Previous messages.

    Returns:
        str: The response content, or "Error" if the call failed.
    """
    content, _finish_reason = ChatGPT_API_with_finish_reason(
        model=model,
        prompt=prompt,
        api_key=api_key,
        chat_history=chat_history,
    )
    # The delegate returns content "Error" in every failure mode, matching
    # this function's historical contract.
    return content

async def ChatGPT_API_async(
    model: str,
    prompt: str,
    api_key: Optional[str] = OPENAI_API_KEY
) -> str:
    """
    Asynchronously call OpenAI Chat Completion API.

    Retries up to 10 times with a 1-second async pause on any API error.

    Args:
        model (str): The model name.
        prompt (str): The user prompt.
        api_key (Optional[str]): OpenAI API key.

    Returns:
        str: The response content, or "Error" if failed.
    """
    max_retries = 10
    if not api_key:
        logging.error("No API key provided.")
        return "Error"

    messages = [{"role": "user", "content": prompt}]
    # Create the client once and reuse it across retries instead of
    # constructing and tearing down a new client on every attempt.
    async with openai.AsyncOpenAI(api_key=api_key) as client:
        for attempt in range(max_retries):
            try:
                response = await client.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=0,
                )
                return response.choices[0].message.content or ""
            except Exception as e:
                logging.error(f"Error: {e}")
                if attempt < max_retries - 1:
                    logging.warning("Retrying async ChatGPT API call (attempt %d/%d)", attempt + 2, max_retries)
                    await asyncio.sleep(1)
                else:
                    logging.error('Max retries reached for prompt: ' + prompt[:50] + '...')
                    return "Error"
    return "Error"

def get_json_content(response: str) -> str:
    """
    Extract the payload of a markdown ```json code block.

    If no opening ```json marker exists, the whole string is used; if no
    closing ``` exists, everything after the marker is kept. The result is
    stripped of surrounding whitespace.

    Args:
        response (str): The full raw response string.

    Returns:
        str: The extracted JSON string stripped of markers.
    """
    _, marker, after = response.partition("```json")
    body = after if marker else response

    closing = body.rfind("```")
    if closing != -1:
        body = body[:closing]

    return body.strip()

def extract_json(content: str) -> Union[Dict[str, Any], List[Any]]:
    """
    Robustly extract and parse JSON from a string, handling common LLM
    formatting issues.

    Strategy: pull the payload out of a ```json fence if present, attempt a
    strict parse first, and only apply lossy clean-ups (Python ``None`` ->
    ``null``, newline flattening, trailing-comma removal) when the strict
    parse fails. This fixes a bug where valid JSON containing the substring
    "None" or embedded newlines inside string values was corrupted by the
    unconditional clean-up.

    Args:
        content (str): The text containing JSON.

    Returns:
        Union[Dict, List]: The parsed JSON value, or {} on failure.
    """
    try:
        # First, try to extract JSON enclosed within ```json and ```.
        start_idx = content.find("```json")
        if start_idx != -1:
            start_idx += 7  # Adjust index to start after the delimiter
            end_idx = content.rfind("```")
            json_content = content[start_idx:end_idx].strip()
        else:
            # If no delimiters, assume the entire content could be JSON.
            json_content = content.strip()

        # Strict parse first: well-formed JSON must never be altered.
        try:
            return json.loads(json_content)
        except json.JSONDecodeError:
            pass

        # Clean up common issues that cause parsing errors, then retry.
        json_content = json_content.replace('None', 'null')  # Python None -> JSON null
        json_content = json_content.replace('\n', ' ').replace('\r', ' ')  # flatten newlines
        json_content = ' '.join(json_content.split())  # normalize whitespace

        try:
            return json.loads(json_content)
        except json.JSONDecodeError:
            # Last resort: drop trailing commas before closing brackets/braces.
            json_content = json_content.replace(',]', ']').replace(',}', '}')
            try:
                return json.loads(json_content)
            except json.JSONDecodeError:
                logging.error("Failed to parse JSON even after cleanup")
                return {}
    except Exception as e:
        logging.error(f"Unexpected error while extracting JSON: {e}")
        return {}
Loading
Loading