diff --git a/services/whisper-service/src/api.py b/services/whisper-service/src/api.py
new file mode 100644
index 0000000..33bdd6e
--- /dev/null
+++ b/services/whisper-service/src/api.py
@@ -0,0 +1,416 @@
+import asyncio
+import json
+import logging
+import os
+import tempfile
+import time
+from pathlib import Path
+from typing import Dict, List, Optional, Any
+
+import uvicorn
+from fastapi import FastAPI, HTTPException, UploadFile, File
+from pydantic import BaseModel
+import redis.asyncio as redis
+import psycopg2
+import structlog
+from faster_whisper import WhisperModel
+import torch
+
+# Configure structured logging
+structlog.configure(
+    processors=[
+        structlog.stdlib.filter_by_level,
+        structlog.stdlib.add_log_level,
+        structlog.stdlib.add_logger_name,
+        structlog.stdlib.PositionalArgumentsFormatter(),
+        structlog.processors.TimeStamper(fmt="iso"),
+        structlog.processors.StackInfoRenderer(),
+        structlog.processors.format_exc_info,
+        structlog.processors.JSONRenderer()
+    ],
+    context_class=dict,
+    logger_factory=structlog.stdlib.LoggerFactory(),
+    wrapper_class=structlog.stdlib.BoundLogger,
+    cache_logger_on_first_use=True,
+)
+
+logger = structlog.get_logger()
+
+# Configuration
+POSTGRES_URL = os.getenv('POSTGRES_URL')
+REDIS_URL = os.getenv('REDIS_URL', 'redis://redis:6379')
+WHISPER_MODEL = os.getenv('WHISPER_MODEL', 'large-v2')
+MODEL_CACHE_DIR = '/app/models'
+LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+COMPUTE_TYPE = 'float16' if DEVICE == 'cuda' else 'int8'
+ALLOWED_EXTENSIONS = {'.wav', '.mp3', '.m4a', '.flac'}
+
+# Set up logging level
+logging.basicConfig(level=getattr(logging, LOG_LEVEL))
+
+# Pydantic models
+class TranscriptionRequest(BaseModel):
+    audio_path: str
+    recording_id: int
+    language: Optional[str] = None
+
+class TranscriptionResult(BaseModel):
+    text: str
+    language: str
+    language_confidence: float
+    segments: List[Dict[str, Any]]
+    processing_time_ms: int
+    model_used: str
+
+class HealthResponse(BaseModel):
+    status: str
+    device: str
+    model: str
+    gpu_available: bool
+    gpu_memory_mb: Optional[int] = None
+
+app = FastAPI(title="Whisper Transcription Service", version="2.0.0")
+
+class WhisperService:
+    def __init__(self):
+        self.model = None
+        self.redis_client = None
+        self.pg_connection = None
+        self.model_loading = False
+
+    async def initialize(self):
+        """Initialize connections and load model"""
+        try:
+            # Fail fast on missing configuration instead of failing on first query
+            if not POSTGRES_URL:
+                raise RuntimeError("POSTGRES_URL environment variable is not set")
+
+            # Connect to Redis
+            self.redis_client = redis.from_url(REDIS_URL)
+            await self.redis_client.ping()
+
+            # Connect to PostgreSQL
+            self.pg_connection = psycopg2.connect(POSTGRES_URL)
+
+            # Load Whisper model
+            await self.load_model()
+
+            logger.info("Whisper service initialized successfully",
+                        device=DEVICE, model=WHISPER_MODEL)
+        except Exception as e:
+            logger.error("Failed to initialize Whisper service", error=str(e))
+            raise
+
+    async def load_model(self):
+        """Load the Whisper model"""
+        if self.model_loading:
+            logger.warning("Model is already loading")
+            return
+
+        self.model_loading = True
+        try:
+            logger.info("Loading Whisper model", model=WHISPER_MODEL, device=DEVICE)
+
+            # Create model cache directory
+            os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
+
+            # Load model with optimizations
+            self.model = WhisperModel(
+                WHISPER_MODEL,
+                device=DEVICE,
+                compute_type=COMPUTE_TYPE,
+                download_root=MODEL_CACHE_DIR,
+                local_files_only=False  # Allow downloading if not cached
+            )
+
+            # Log GPU memory usage if available
+            if DEVICE == 'cuda' and torch.cuda.is_available():
+                gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**2
+                used_memory = torch.cuda.memory_allocated(0) / 1024**2
+                logger.info("GPU memory status",
+                            total_mb=gpu_memory, used_mb=used_memory)
+
+            logger.info("Whisper model loaded successfully")
+
+        except Exception as e:
+            logger.error("Failed to load Whisper model", error=str(e))
+            raise
+        finally:
+            self.model_loading = False
+
+    async def cleanup(self):
+        """Clean up connections"""
+        if self.redis_client:
+            await self.redis_client.close()
+        if self.pg_connection:
+            self.pg_connection.close()
+
+    async def transcribe_audio(self, audio_path: str, recording_id: int,
+                               language: Optional[str] = None) -> TranscriptionResult:
+        """Transcribe audio file"""
+        if not self.model:
+            raise HTTPException(status_code=503, detail="Model not loaded")
+
+        if not os.path.exists(audio_path):
+            raise HTTPException(status_code=404, detail="Audio file not found")
+
+        start_time = time.time()
+
+        try:
+            logger.info("Starting transcription",
+                        recording_id=recording_id, audio_path=audio_path)
+
+            def run_transcription():
+                # faster-whisper is synchronous and its generator decodes lazily,
+                # so both the call and the iteration happen inside this helper.
+                segments_generator, info = self.model.transcribe(
+                    audio_path,
+                    language=language,
+                    beam_size=5,
+                    best_of=5,
+                    temperature=0.0,
+                    condition_on_previous_text=False,
+                    word_timestamps=True,  # required for segment.words to be populated
+                    vad_filter=True,
+                    vad_parameters=dict(min_silence_duration_ms=500)
+                )
+                return list(segments_generator), info
+
+            # Run the blocking transcription in a worker thread so the event
+            # loop (and the /health endpoint) stays responsive
+            raw_segments, info = await asyncio.to_thread(run_transcription)
+
+            # Convert segments to serializable dicts and collect text
+            segments = []
+            full_text_parts = []
+
+            for segment in raw_segments:
+                segment_dict = {
+                    'start': segment.start,
+                    'end': segment.end,
+                    'text': segment.text.strip(),
+                    'words': [
+                        {
+                            'start': word.start,
+                            'end': word.end,
+                            'word': word.word,
+                            'probability': word.probability
+                        } for word in segment.words
+                    ] if segment.words else []
+                }
+                segments.append(segment_dict)
+                full_text_parts.append(segment.text.strip())
+
+            # Combine all text
+            full_text = ' '.join(full_text_parts).strip()
+
+            processing_time_ms = int((time.time() - start_time) * 1000)
+
+            # Create result
+            result = TranscriptionResult(
+                text=full_text,
+                language=info.language,
+                language_confidence=info.language_probability,
+                segments=segments,
+                processing_time_ms=processing_time_ms,
+                model_used=WHISPER_MODEL
+            )
+
+            logger.info("Transcription completed",
+                        recording_id=recording_id,
+                        language=info.language,
+                        confidence=info.language_probability,
+                        text_length=len(full_text),
+                        processing_time_ms=processing_time_ms)
+
+            return result
+
+        except Exception as e:
+            logger.error("Error during transcription",
+                         error=str(e), recording_id=recording_id)
+            raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
+
+    async def update_transcription_record(self, recording_id: int, result: TranscriptionResult):
+        """Update transcription record in database"""
+        try:
+            with self.pg_connection.cursor() as cursor:
+                # Insert transcription record
+                cursor.execute("""
+                    INSERT INTO transcriptions
+                    (recording_id, detected_language, language_confidence, transcription_text,
+                     word_count, processing_time_ms, whisper_model, status)
+                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
+                    RETURNING id, uuid
+                """, (
+                    recording_id,
+                    result.language,
+                    result.language_confidence,
+                    result.text,
+                    len(result.text.split()) if result.text else 0,
+                    result.processing_time_ms,
+                    result.model_used,
+                    'completed'
+                ))
+
+                transcription_record = cursor.fetchone()
+                self.pg_connection.commit()
+
+                logger.info("Created transcription record",
+                            transcription_id=transcription_record[0],
+                            recording_id=recording_id)
+
+                return transcription_record[0]  # Return transcription ID
+
+        except Exception as e:
+            logger.error("Failed to update transcription record",
+                         error=str(e), recording_id=recording_id)
+            self.pg_connection.rollback()
+            raise
+
+    async def publish_event(self, event: str, data: Dict[str, Any]):
+        """Publish event to Redis"""
+        try:
+            message = {
+                'event': event,
+                'data': data,
+                'timestamp': time.time(),
+                'service': 'whisper-service'
+            }
+
+            await self.redis_client.publish('voice-translator-events', json.dumps(message))
+            logger.debug("Published event", event=event, data=data)
+
+        except Exception as e:
+            logger.error("Failed to publish event", error=str(e), event=event)
+
+    async def process_audio_event(self, message_data: Dict[str, Any]):
+        """Process audio_processed event"""
+        recording_id = message_data.get('recordingId')
+        processed_path = message_data.get('processedPath')
+
+        if not recording_id or not processed_path:
+            logger.error("Invalid audio_processed message", data=message_data)
+            return
+
+        try:
+            # Transcribe the audio
+            result = await self.transcribe_audio(processed_path, recording_id)
+
+            # Update database
+            transcription_id = await self.update_transcription_record(recording_id, result)
+
+            # Publish transcription completed event
+            await self.publish_event('transcription_completed', {
+                'recordingId': recording_id,
+                'transcriptionId': transcription_id,
+                'text': result.text,
+                'language': result.language,
+                'languageConfidence': result.language_confidence,
+                'processingTimeMs': result.processing_time_ms
+            })
+
+        except Exception as e:
+            logger.error("Error processing audio event",
+                         error=str(e), recording_id=recording_id)
+
+    async def listen_for_events(self):
+        """Listen for audio processing events"""
+        logger.info("Starting to listen for audio events")
+
+        pubsub = self.redis_client.pubsub()
+        await pubsub.subscribe('voice-translator-events')
+
+        try:
+            async for message in pubsub.listen():
+                if message['type'] == 'message':
+                    try:
+                        data = json.loads(message['data'])
+
+                        # Process audio_processed events
+                        if data.get('event') == 'audio_processed':
+                            await self.process_audio_event(data['data'])
+
+                    except json.JSONDecodeError:
+                        logger.error("Invalid JSON in message", data=message['data'])
+                    except Exception as e:
+                        logger.error("Error processing message", error=str(e))
+
+        except Exception as e:
+            logger.error("Error in message listener", error=str(e))
+        finally:
+            await pubsub.unsubscribe('voice-translator-events')
+
+# Global service instance
+whisper_service = WhisperService()
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize the service on startup"""
+    await whisper_service.initialize()
+
+    # Keep a reference to the listener task so it is not garbage-collected
+    app.state.event_listener = asyncio.create_task(whisper_service.listen_for_events())
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    """Cleanup on shutdown"""
+    listener = getattr(app.state, 'event_listener', None)
+    if listener:
+        listener.cancel()
+    await whisper_service.cleanup()
+
+@app.get("/health", response_model=HealthResponse)
+async def health_check():
+    """Health check endpoint"""
+    gpu_memory = None
+    if torch.cuda.is_available():
+        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**2
+
+    return HealthResponse(
+        status="healthy" if whisper_service.model else "loading",
+        device=DEVICE,
+        model=WHISPER_MODEL,
+        gpu_available=torch.cuda.is_available(),
+        gpu_memory_mb=int(gpu_memory) if gpu_memory else None
+    )
+
+@app.post("/transcribe", response_model=TranscriptionResult)
+async def transcribe_endpoint(request: TranscriptionRequest):
+    """Manual transcription endpoint"""
+    return await whisper_service.transcribe_audio(
+        request.audio_path,
+        request.recording_id,
+        request.language
+    )
+
+@app.post("/transcribe/upload")
+async def transcribe_upload(file: UploadFile = File(...), recording_id: int = 0):
+    """Upload and transcribe audio file"""
+    suffix = Path(file.filename or '').suffix.lower()
+    if suffix not in ALLOWED_EXTENSIONS:
+        raise HTTPException(status_code=400, detail="Unsupported audio format")
+
+    # Save the upload to a secure temp file rather than trusting the
+    # client-supplied filename as a /tmp path
+    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
+        tmp.write(await file.read())
+        temp_path = tmp.name
+
+    try:
+        # Transcribe
+        return await whisper_service.transcribe_audio(temp_path, recording_id)
+
+    finally:
+        # Clean up temp file
+        os.remove(temp_path)
+
+@app.get("/models")
+async def list_models():
+    """List available Whisper models"""
+    return {
+        "current_model": WHISPER_MODEL,
+        "device": DEVICE,
+        "compute_type": COMPUTE_TYPE,
+        "available_models": [
+            "tiny", "tiny.en", "base", "base.en",
+            "small", "small.en", "medium", "medium.en",
+            "large", "large-v1", "large-v2", "large-v3"
+        ]
+    }
+
+if __name__ == "__main__":
+    uvicorn.run(
+        "api:app",
+        host="0.0.0.0",
+        port=8000,
+        log_level=LOG_LEVEL.lower(),
+        access_log=True
+    )
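
For a quick smoke test of the endpoints this diff adds, a minimal client sketch follows. Assumptions not in the diff: the service is reachable at `http://localhost:8000` (the default host/port from the `__main__` block), the third-party `requests` package is installed, and `sample.wav` is a placeholder for a real local file; `recording_id=0` just exercises the endpoint without tying the result to a database row.

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: default host/port from the __main__ block

# Confirm the model has finished loading before sending work
health = requests.get(f"{BASE_URL}/health", timeout=10).json()
print(health["status"], health["device"], health["model"])

# Upload a local file for ad-hoc transcription; recording_id is a query parameter
with open("sample.wav", "rb") as audio:  # placeholder path
    resp = requests.post(
        f"{BASE_URL}/transcribe/upload",
        params={"recording_id": 0},
        files={"file": ("sample.wav", audio, "audio/wav")},
        timeout=600,  # large models can take a while, especially on CPU
    )
resp.raise_for_status()
result = resp.json()
print(result["language"], result["text"])
```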