mirror of https://github.com/zebrajr/localGPT.git (synced 2025-12-06 12:20:53 +01:00)
Removed unused import statements from various Python files to improve code clarity and reduce unnecessary dependencies.
443 lines
17 KiB
Python
import json
import threading
import time
from typing import Dict, List, Any
import logging
from urllib.parse import urlparse, parse_qs
import http.server
import socketserver

# Import the core logic and batch processing utilities
from rag_system.main import get_agent
from rag_system.utils.batch_processor import ProgressTracker, timer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global progress tracking storage
ACTIVE_PROGRESS_SESSIONS: Dict[str, Dict[str, Any]] = {}

# --- Global Singleton for the RAG Agent ---
print("🧠 Initializing RAG Agent... (This may take a moment)")
RAG_AGENT = get_agent()
if RAG_AGENT is None:
    print("❌ Critical error: RAG Agent could not be initialized. Exiting.")
    exit(1)
print("✅ RAG Agent initialized successfully.")

class ServerSentEventsHandler:
    """Handler for Server-Sent Events (SSE) for real-time progress updates"""

    active_connections: Dict[str, Any] = {}

    @classmethod
    def add_connection(cls, session_id: str, response_handler):
        """Add a new SSE connection"""
        cls.active_connections[session_id] = response_handler
        logger.info(f"SSE connection added for session: {session_id}")

    @classmethod
    def remove_connection(cls, session_id: str):
        """Remove an SSE connection"""
        if session_id in cls.active_connections:
            del cls.active_connections[session_id]
            logger.info(f"SSE connection removed for session: {session_id}")

    @classmethod
    def send_event(cls, session_id: str, event_type: str, data: Dict[str, Any]):
        """Send an SSE event to a specific session"""
        if session_id not in cls.active_connections:
            return

        try:
            handler = cls.active_connections[session_id]
            event_data = json.dumps(data)
            message = f"event: {event_type}\ndata: {event_data}\n\n"
            handler.wfile.write(message.encode('utf-8'))
            handler.wfile.flush()
        except Exception as e:
            logger.error(f"Failed to send SSE event: {e}")
            cls.remove_connection(session_id)

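# Illustrative sketch (not part of the original module): a call such as
#   ServerSentEventsHandler.send_event("abc123", "progress", {"progress_percentage": 50})
# writes one SSE frame to the open /stream response, i.e. the bytes
#   event: progress
#   data: {"progress_percentage": 50}
# followed by a blank line that terminates the frame.
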
class RealtimeProgressTracker(ProgressTracker):
    """Enhanced ProgressTracker that sends updates via Server-Sent Events"""

    def __init__(self, total_items: int, operation_name: str, session_id: str):
        super().__init__(total_items, operation_name)
        self.session_id = session_id
        self.last_update = 0
        self.update_interval = 1  # Update every 1 second

        # Initialize session progress
        ACTIVE_PROGRESS_SESSIONS[session_id] = {
            "operation_name": operation_name,
            "total_items": total_items,
            "processed_items": 0,
            "errors_encountered": 0,
            "start_time": self.start_time,
            "status": "running",
            "current_step": "",
            "eta_seconds": 0,
            "throughput": 0,
            "progress_percentage": 0
        }

        # Send initial progress update
        self._send_progress_update()

    def update(self, items_processed: int, errors: int = 0, current_step: str = ""):
        """Update progress and send notification"""
        super().update(items_processed, errors)

        # Update session data
        session_data = ACTIVE_PROGRESS_SESSIONS.get(self.session_id)
        if session_data:
            session_data.update({
                "processed_items": self.processed_items,
                "errors_encountered": self.errors_encountered,
                "current_step": current_step,
                "progress_percentage": (self.processed_items / self.total_items) * 100,
            })

            # Calculate throughput and ETA
            elapsed = time.time() - self.start_time
            if elapsed > 0:
                session_data["throughput"] = self.processed_items / elapsed
                remaining = self.total_items - self.processed_items
                session_data["eta_seconds"] = remaining / session_data["throughput"] if session_data["throughput"] > 0 else 0

        # Send update if enough time has passed
        current_time = time.time()
        if current_time - self.last_update >= self.update_interval:
            self._send_progress_update()
            self.last_update = current_time

    def finish(self):
        """Mark progress as finished and send final update"""
        super().finish()

        # Update session status
        session_data = ACTIVE_PROGRESS_SESSIONS.get(self.session_id)
        if session_data:
            session_data.update({
                "status": "completed",
                "progress_percentage": 100,
                "eta_seconds": 0
            })

        # Send final update
        self._send_progress_update(final=True)

    def _send_progress_update(self, final: bool = False):
        """Send progress update via Server-Sent Events"""
        session_data = ACTIVE_PROGRESS_SESSIONS.get(self.session_id, {})

        event_data = {
            "session_id": self.session_id,
            "progress": session_data.copy(),
            "final": final,
            "timestamp": time.time()
        }

        ServerSentEventsHandler.send_event(self.session_id, "progress", event_data)

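# For reference (a sketch derived from _send_progress_update above, not from the original
# file): each "progress" SSE event carries a JSON object of the form
#   {"session_id": "...", "final": false, "timestamp": 1700000000.0,
#    "progress": {"operation_name": "...", "total_items": 6, "processed_items": 1,
#                 "errors_encountered": 0, "status": "running", "current_step": "...",
#                 "eta_seconds": 0, "throughput": 0, "progress_percentage": 0,
#                 "start_time": ...}}
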
def run_indexing_with_progress(file_paths: List[str], session_id: str):
    """Enhanced indexing function with real-time progress tracking"""
    from rag_system.pipelines.indexing_pipeline import IndexingPipeline
    from rag_system.utils.ollama_client import OllamaClient

    try:
        # Send initial status
        ServerSentEventsHandler.send_event(session_id, "status", {
            "message": "Initializing indexing pipeline...",
            "session_id": session_id
        })

        # Load configuration
        config_file = "batch_indexing_config.json"
        try:
            with open(config_file, 'r') as f:
                config = json.load(f)
        except FileNotFoundError:
            # Fallback to default config
            config = {
                "embedding_model_name": "Qwen/Qwen3-Embedding-0.6B",
                "indexing": {
                    "embedding_batch_size": 50,
                    "enrichment_batch_size": 10,
                    "enable_progress_tracking": True
                },
                "contextual_enricher": {"enabled": True, "window_size": 1},
                "retrievers": {
                    "dense": {"enabled": True, "lancedb_table_name": "default_text_table"},
                    "bm25": {"enabled": True, "index_name": "default_bm25_index"}
                },
                "storage": {
                    "chunk_store_path": "./index_store/chunks/chunks.pkl",
                    "lancedb_uri": "./index_store/lancedb",
                    "bm25_path": "./index_store/bm25"
                }
            }

        # Initialize components
        ollama_client = OllamaClient()
        ollama_config = {
            "generation_model": "llama3.2:1b",
            "embedding_model": "mxbai-embed-large"
        }

        # Create enhanced pipeline
        pipeline = IndexingPipeline(config, ollama_client, ollama_config)

        # Create progress tracker for the overall process
        total_steps = 6  # Rough estimate of pipeline steps
        step_tracker = RealtimeProgressTracker(total_steps, "Document Indexing", session_id)

        with timer("Complete Indexing Pipeline"):
            try:
                # Step 1: Document Processing
                step_tracker.update(1, current_step="Processing documents...")

                # Run the indexing pipeline
                pipeline.run(file_paths)

                # Update progress through the steps
                step_tracker.update(1, current_step="Chunking completed...")
                step_tracker.update(1, current_step="BM25 indexing completed...")
                step_tracker.update(1, current_step="Contextual enrichment completed...")
                step_tracker.update(1, current_step="Vector embeddings completed...")
                step_tracker.update(1, current_step="Indexing finalized...")

                step_tracker.finish()

                # Send completion notification
                ServerSentEventsHandler.send_event(session_id, "completion", {
                    "message": f"Successfully indexed {len(file_paths)} file(s)",
                    "file_count": len(file_paths),
                    "session_id": session_id
                })

            except Exception as e:
                # Send error notification
                ServerSentEventsHandler.send_event(session_id, "error", {
                    "message": str(e),
                    "session_id": session_id
                })
                raise

    except Exception as e:
        logger.error(f"Indexing failed for session {session_id}: {e}")
        ServerSentEventsHandler.send_event(session_id, "error", {
            "message": str(e),
            "session_id": session_id
        })
        raise

class EnhancedRagApiHandler(http.server.BaseHTTPRequestHandler):
    """Enhanced API handler with progress tracking support"""

    def do_OPTIONS(self):
        """Handle CORS preflight requests for frontend integration."""
        self.send_response(200)
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')
        self.end_headers()

    def do_GET(self):
        """Handle GET requests for progress status and SSE streams"""
        parsed_path = urlparse(self.path)

        if parsed_path.path == '/progress':
            self.handle_progress_status()
        elif parsed_path.path == '/stream':
            self.handle_progress_stream()
        else:
            self.send_json_response({"error": "Not Found"}, status_code=404)

    def do_POST(self):
        """Handle POST requests for chat and indexing."""
        parsed_path = urlparse(self.path)

        if parsed_path.path == '/chat':
            self.handle_chat()
        elif parsed_path.path == '/index':
            self.handle_index_with_progress()
        else:
            self.send_json_response({"error": "Not Found"}, status_code=404)

    def handle_chat(self):
        """Handles a chat query by calling the agentic RAG pipeline."""
        try:
            content_length = int(self.headers['Content-Length'])
            post_data = self.rfile.read(content_length)
            data = json.loads(post_data.decode('utf-8'))

            query = data.get('query')
            if not query:
                self.send_json_response({"error": "Query is required"}, status_code=400)
                return

            # Use the single, persistent agent instance to run the query
            result = RAG_AGENT.run(query)

            # The result is a dict; send_json_response serializes it to a JSON string
            self.send_json_response(result)

        except json.JSONDecodeError:
            self.send_json_response({"error": "Invalid JSON"}, status_code=400)
        except Exception as e:
            self.send_json_response({"error": f"Server error: {str(e)}"}, status_code=500)

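    # Illustrative request for handle_chat() above (a sketch, not from the original file):
    #   POST /chat with body {"query": "What documents are indexed?"}
    # The JSON dict returned by RAG_AGENT.run(query) is sent back to the client as-is.
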
    def handle_index_with_progress(self):
        """Triggers the document indexing pipeline with real-time progress tracking."""
        try:
            content_length = int(self.headers['Content-Length'])
            post_data = self.rfile.read(content_length)
            data = json.loads(post_data.decode('utf-8'))

            file_paths = data.get('file_paths')
            session_id = data.get('session_id')

            if not file_paths or not isinstance(file_paths, list):
                self.send_json_response({
                    "error": "A 'file_paths' list is required."
                }, status_code=400)
                return

            if not session_id:
                self.send_json_response({
                    "error": "A 'session_id' is required for progress tracking."
                }, status_code=400)
                return

            # Start indexing in a separate thread to avoid blocking
            def run_indexing_thread():
                try:
                    run_indexing_with_progress(file_paths, session_id)
                except Exception as e:
                    logger.error(f"Indexing thread failed: {e}")

            thread = threading.Thread(target=run_indexing_thread)
            thread.daemon = True
            thread.start()

            # Return an immediate response; derive the stream URL from this server's
            # actual bound port so it stays correct regardless of the port used
            server_port = self.server.server_address[1]
            self.send_json_response({
                "message": f"Indexing started for {len(file_paths)} file(s)",
                "session_id": session_id,
                "status": "started",
                "progress_stream_url": f"http://localhost:{server_port}/stream?session_id={session_id}"
            })

        except json.JSONDecodeError:
            self.send_json_response({"error": "Invalid JSON"}, status_code=400)
        except Exception as e:
            self.send_json_response({"error": f"Failed to start indexing: {str(e)}"}, status_code=500)

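    # Illustrative request for handle_index_with_progress() above (a sketch; the file path
    # is hypothetical):
    #   POST /index with body {"file_paths": ["./docs/example.pdf"], "session_id": "abc123"}
    # The call returns immediately with status "started"; progress is then available via
    # GET /progress?session_id=abc123 (polling) or GET /stream?session_id=abc123 (SSE).
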
    def handle_progress_status(self):
        """Handle GET requests for current progress status"""
        parsed_url = urlparse(self.path)
        params = parse_qs(parsed_url.query)
        session_id = params.get('session_id', [None])[0]

        if not session_id:
            self.send_json_response({"error": "session_id is required"}, status_code=400)
            return

        progress_data = ACTIVE_PROGRESS_SESSIONS.get(session_id)
        if not progress_data:
            self.send_json_response({"error": "No active progress for this session"}, status_code=404)
            return

        self.send_json_response({
            "session_id": session_id,
            "progress": progress_data
        })

    def handle_progress_stream(self):
        """Handle Server-Sent Events stream for real-time progress"""
        parsed_url = urlparse(self.path)
        params = parse_qs(parsed_url.query)
        session_id = params.get('session_id', [None])[0]

        if not session_id:
            self.send_response(400)
            self.end_headers()
            return

        # Set up SSE headers
        self.send_response(200)
        self.send_header('Content-Type', 'text/event-stream')
        self.send_header('Cache-Control', 'no-cache')
        self.send_header('Connection', 'keep-alive')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.end_headers()

        # Add this connection to the SSE handler
        ServerSentEventsHandler.add_connection(session_id, self)

        # Send initial connection message
        initial_message = json.dumps({
            "session_id": session_id,
            "message": "Progress stream connected",
            "timestamp": time.time()
        })
        self.wfile.write(f"event: connected\ndata: {initial_message}\n\n".encode('utf-8'))
        self.wfile.flush()

        # Keep connection alive
        try:
            while session_id in ServerSentEventsHandler.active_connections:
                time.sleep(1)
                # Send heartbeat
                heartbeat = json.dumps({"type": "heartbeat", "timestamp": time.time()})
                self.wfile.write(f"event: heartbeat\ndata: {heartbeat}\n\n".encode('utf-8'))
                self.wfile.flush()
        except Exception as e:
            logger.info(f"SSE connection closed for session {session_id}: {e}")
        finally:
            ServerSentEventsHandler.remove_connection(session_id)

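    # Hypothetical client-side sketch for the /stream endpoint above (assumes the
    # third-party `requests` package and the default port 8000; not part of this server):
    #   import json, requests
    #   with requests.get("http://localhost:8000/stream",
    #                     params={"session_id": "abc123"}, stream=True) as resp:
    #       for line in resp.iter_lines(decode_unicode=True):
    #           if line.startswith("data: "):
    #               print(json.loads(line[len("data: "):]))
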
    def send_json_response(self, data, status_code=200):
        """Utility to send a JSON response with CORS headers."""
        self.send_response(status_code)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.end_headers()
        response = json.dumps(data, indent=2)
        self.wfile.write(response.encode('utf-8'))

def start_enhanced_server(port=8000):
    """Start the enhanced API server with a reusable TCP socket."""

    # Use a custom TCPServer that allows address reuse
    class ReusableTCPServer(socketserver.TCPServer):
        allow_reuse_address = True

    with ReusableTCPServer(("", port), EnhancedRagApiHandler) as httpd:
        print(f"🚀 Starting Enhanced RAG API server on port {port}")
        print(f"💬 Chat endpoint: http://localhost:{port}/chat")
        print(f"✨ Indexing endpoint: http://localhost:{port}/index")
        print(f"📊 Progress endpoint: http://localhost:{port}/progress")
        print(f"🌊 Progress stream: http://localhost:{port}/stream")
        print("📈 Real-time progress tracking enabled via Server-Sent Events!")
        httpd.serve_forever()

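# Note: socketserver.TCPServer handles requests sequentially, so a long-lived /stream
# connection occupies the server until it closes; if concurrent clients are needed,
# socketserver.ThreadingTCPServer is the usual drop-in alternative.
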
if __name__ == '__main__':
    # Start the server on a dedicated thread
    server_thread = threading.Thread(target=start_enhanced_server)
    server_thread.daemon = True
    server_thread.start()

    print("🚀 Enhanced RAG API server with progress tracking is running.")
    print("Press Ctrl+C to stop.")

    # Keep the main thread alive
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nStopping server...")