diff --git a/services/translator/src/api.py b/services/translator/src/api.py new file mode 100644 index 0000000..825f180 --- /dev/null +++ b/services/translator/src/api.py @@ -0,0 +1,566 @@ +import asyncio +import json +import logging +import os +import time +from typing import Dict, List, Optional, Any +from enum import Enum + +import uvicorn +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import redis.asyncio as redis +import psycopg2 +from psycopg2.extras import DictCursor +import structlog +import torch +from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer +from googletrans import Translator as GoogleTranslator +from deep_translator import GoogleTranslator as DeepGoogleTranslator +import httpx +from cachetools import TTLCache + +# Configure structured logging +structlog.configure( + processors=[ + structlog.stdlib.filter_by_level, + structlog.stdlib.add_log_level, + structlog.stdlib.add_logger_name, + structlog.stdlib.PositionalArgumentsFormatter(), + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.format_exc_info, + structlog.processors.JSONRenderer() + ], + context_class=dict, + logger_factory=structlog.stdlib.LoggerFactory(), + wrapper_class=structlog.stdlib.BoundLogger, + cache_logger_on_first_use=True, +) + +logger = structlog.get_logger() + +# Configuration +POSTGRES_URL = os.getenv('POSTGRES_URL') +REDIS_URL = os.getenv('REDIS_URL', 'redis://redis:6379') +GOOGLE_TRANSLATE_API_KEY = os.getenv('GOOGLE_TRANSLATE_API_KEY') +MODEL_CACHE_DIR = '/app/models' +LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO') +DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' + +# Set up logging level +logging.basicConfig(level=getattr(logging, LOG_LEVEL)) + +# Language mappings +class Language(str, Enum): + ENGLISH = "en" + GERMAN = "de" + KOREAN = "ko" + +# NLLB language code mappings +NLLB_LANG_MAP = { + 'en': 'eng_Latn', + 'de': 'deu_Latn', + 'ko': 'kor_Hang', + 'es': 'spa_Latn', + 'fr': 'fra_Latn', + 'it': 'ita_Latn', + 'pt': 'por_Latn', + 'ru': 'rus_Cyrl', + 'ja': 'jpn_Jpan', + 'zh': 'zho_Hans', + 'ar': 'arb_Arab' +} + +# Language display names and flags +LANGUAGE_INFO = { + 'en': {'name': 'English', 'flag': '🇺🇸'}, + 'de': {'name': 'German', 'flag': '🇩🇪'}, + 'ko': {'name': 'Korean', 'flag': '🇰🇷'}, + 'es': {'name': 'Spanish', 'flag': '🇪🇸'}, + 'fr': {'name': 'French', 'flag': '🇫🇷'}, + 'it': {'name': 'Italian', 'flag': '🇮🇹'}, + 'pt': {'name': 'Portuguese', 'flag': '🇵🇹'}, + 'ru': {'name': 'Russian', 'flag': '🇷🇺'}, + 'ja': {'name': 'Japanese', 'flag': '🇯🇵'}, + 'zh': {'name': 'Chinese', 'flag': '🇨🇳'}, + 'ar': {'name': 'Arabic', 'flag': '🇸🇦'} +} + +# Pydantic models +class TranslationRequest(BaseModel): + text: str + source_language: str + target_language: str + transcription_id: Optional[int] = None + +class BulkTranslationRequest(BaseModel): + text: str + source_language: str + target_languages: List[str] + transcription_id: Optional[int] = None + +class TranslationResult(BaseModel): + translated_text: str + source_language: str + target_language: str + confidence_score: Optional[float] = None + processing_time_ms: int + service_used: str + +class BulkTranslationResult(BaseModel): + source_text: str + source_language: str + translations: Dict[str, TranslationResult] + total_processing_time_ms: int + +class HealthResponse(BaseModel): + status: str + local_model_loaded: bool + device: str + supported_languages: List[str] + google_api_available: bool + +app = FastAPI(title="Translation Service", version="2.0.0") + +class TranslationService: + def __init__(self): + self.nllb_model = None + self.nllb_tokenizer = None + self.google_translator = None + self.deep_translator = None + self.redis_client = None + self.pg_connection = None + self.model_loading = False + + # Translation cache (TTL: 1 hour) + self.translation_cache = TTLCache(maxsize=1000, ttl=3600) + + async def initialize(self): + """Initialize connections and load models""" + try: + # Connect to Redis + self.redis_client = redis.from_url(REDIS_URL) + await self.redis_client.ping() + + # Connect to PostgreSQL + self.pg_connection = psycopg2.connect(POSTGRES_URL) + + # Initialize Google Translators + self.google_translator = GoogleTranslator() + self.deep_translator = DeepGoogleTranslator(api_key=GOOGLE_TRANSLATE_API_KEY) if GOOGLE_TRANSLATE_API_KEY else None + + # Load local NLLB model + await self.load_local_model() + + logger.info("Translation service initialized successfully", device=DEVICE) + except Exception as e: + logger.error("Failed to initialize translation service", error=str(e)) + raise + + async def load_local_model(self): + """Load the local NLLB model""" + if self.model_loading: + logger.warning("Model is already loading") + return + + self.model_loading = True + try: + logger.info("Loading NLLB translation model", device=DEVICE) + + # Create model cache directory + os.makedirs(MODEL_CACHE_DIR, exist_ok=True) + + # Load NLLB-200 distilled model (600M parameters - good balance of speed/quality) + model_name = "facebook/nllb-200-distilled-600M" + + self.nllb_tokenizer = M2M100Tokenizer.from_pretrained( + model_name, + cache_dir=MODEL_CACHE_DIR + ) + + self.nllb_model = M2M100ForConditionalGeneration.from_pretrained( + model_name, + cache_dir=MODEL_CACHE_DIR, + torch_dtype=torch.float16 if DEVICE == 'cuda' else torch.float32 + ).to(DEVICE) + + if DEVICE == 'cuda': + self.nllb_model.half() # Use half precision for faster inference + + logger.info("NLLB model loaded successfully") + + except Exception as e: + logger.error("Failed to load NLLB model", error=str(e)) + # Continue without local model - will use web APIs + finally: + self.model_loading = False + + async def cleanup(self): + """Clean up connections""" + if self.redis_client: + await self.redis_client.close() + if self.pg_connection: + self.pg_connection.close() + + def get_cache_key(self, text: str, source_lang: str, target_lang: str) -> str: + """Generate cache key for translation""" + return f"{source_lang}:{target_lang}:{hash(text)}" + + async def translate_with_nllb(self, text: str, source_lang: str, target_lang: str) -> Optional[str]: + """Translate using local NLLB model""" + if not self.nllb_model or not self.nllb_tokenizer: + return None + + try: + # Map language codes to NLLB format + source_nllb = NLLB_LANG_MAP.get(source_lang) + target_nllb = NLLB_LANG_MAP.get(target_lang) + + if not source_nllb or not target_nllb: + logger.warning("Language not supported by NLLB", + source=source_lang, target=target_lang) + return None + + # Set source language + self.nllb_tokenizer.src_lang = source_nllb + + # Tokenize input text + encoded = self.nllb_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) + + if DEVICE == 'cuda': + encoded = {k: v.to(DEVICE) for k, v in encoded.items()} + + # Generate translation + generated_tokens = self.nllb_model.generate( + **encoded, + forced_bos_token_id=self.nllb_tokenizer.lang_code_to_id[target_nllb], + max_length=512, + num_beams=4, + length_penalty=1.0, + early_stopping=True + ) + + # Decode translation + translated_text = self.nllb_tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + )[0] + + return translated_text.strip() + + except Exception as e: + logger.error("Error in NLLB translation", + error=str(e), source=source_lang, target=target_lang) + return None + + async def translate_with_google(self, text: str, source_lang: str, target_lang: str) -> Optional[str]: + """Translate using Google Translate API""" + try: + # Try premium API first if available + if self.deep_translator: + try: + result = self.deep_translator.translate(text, target=target_lang, source=source_lang) + return result + except Exception as e: + logger.warning("Premium Google API failed, trying free version", error=str(e)) + + # Fallback to free Google Translate + result = self.google_translator.translate(text, dest=target_lang, src=source_lang) + return result.text + + except Exception as e: + logger.error("Error in Google translation", + error=str(e), source=source_lang, target=target_lang) + return None + + async def translate_text(self, text: str, source_lang: str, target_lang: str, + transcription_id: Optional[int] = None) -> TranslationResult: + """Translate text with fallback strategy""" + start_time = time.time() + + # Check cache first + cache_key = self.get_cache_key(text, source_lang, target_lang) + if cache_key in self.translation_cache: + cached_result = self.translation_cache[cache_key] + logger.debug("Using cached translation", source=source_lang, target=target_lang) + return TranslationResult( + translated_text=cached_result, + source_language=source_lang, + target_language=target_lang, + processing_time_ms=int((time.time() - start_time) * 1000), + service_used="cache" + ) + + translated_text = None + service_used = "none" + + # Try local NLLB model first + if self.nllb_model: + translated_text = await self.translate_with_nllb(text, source_lang, target_lang) + if translated_text: + service_used = "nllb_local" + + # Fallback to Google Translate if local model failed + if not translated_text: + translated_text = await self.translate_with_google(text, source_lang, target_lang) + if translated_text: + service_used = "google_translate" + + if not translated_text: + raise HTTPException(status_code=500, detail="All translation services failed") + + processing_time_ms = int((time.time() - start_time) * 1000) + + # Cache the result + self.translation_cache[cache_key] = translated_text + + result = TranslationResult( + translated_text=translated_text, + source_language=source_lang, + target_language=target_lang, + processing_time_ms=processing_time_ms, + service_used=service_used + ) + + # Update database if transcription_id provided + if transcription_id: + await self.save_translation_record(transcription_id, result) + + logger.info("Translation completed", + source=source_lang, target=target_lang, + service=service_used, time_ms=processing_time_ms) + + return result + + async def bulk_translate(self, text: str, source_lang: str, target_languages: List[str], + transcription_id: Optional[int] = None) -> BulkTranslationResult: + """Translate to multiple target languages""" + start_time = time.time() + translations = {} + + # Translate to each target language + for target_lang in target_languages: + if source_lang == target_lang: + # Skip translation if source equals target + translations[target_lang] = TranslationResult( + translated_text=text, + source_language=source_lang, + target_language=target_lang, + processing_time_ms=0, + service_used="no_translation_needed" + ) + else: + try: + translation = await self.translate_text(text, source_lang, target_lang, transcription_id) + translations[target_lang] = translation + except Exception as e: + logger.error("Failed to translate to language", + target=target_lang, error=str(e)) + # Create error result + translations[target_lang] = TranslationResult( + translated_text=f"[Translation failed: {str(e)}]", + source_language=source_lang, + target_language=target_lang, + processing_time_ms=0, + service_used="error" + ) + + total_time_ms = int((time.time() - start_time) * 1000) + + return BulkTranslationResult( + source_text=text, + source_language=source_lang, + translations=translations, + total_processing_time_ms=total_time_ms + ) + + async def save_translation_record(self, transcription_id: int, result: TranslationResult): + """Save translation to database""" + try: + with self.pg_connection.cursor() as cursor: + cursor.execute(""" + INSERT INTO translations + (transcription_id, target_language, translated_text, translation_service, + processing_time_ms, status) + VALUES (%s, %s, %s, %s, %s, %s) + RETURNING id + """, ( + transcription_id, + result.target_language, + result.translated_text, + result.service_used, + result.processing_time_ms, + 'completed' + )) + + translation_id = cursor.fetchone()[0] + self.pg_connection.commit() + + logger.debug("Saved translation record", + translation_id=translation_id, transcription_id=transcription_id) + + except Exception as e: + logger.error("Failed to save translation record", + error=str(e), transcription_id=transcription_id) + self.pg_connection.rollback() + + async def publish_event(self, event: str, data: Dict[str, Any]): + """Publish event to Redis""" + try: + message = { + 'event': event, + 'data': data, + 'timestamp': time.time(), + 'service': 'translator' + } + + await self.redis_client.publish('voice-translator-events', json.dumps(message)) + logger.debug("Published event", event=event, data=data) + + except Exception as e: + logger.error("Failed to publish event", error=str(e), event=event) + + async def process_transcription_event(self, message_data: Dict[str, Any]): + """Process transcription_completed event""" + transcription_id = message_data.get('transcriptionId') + text = message_data.get('text') + detected_language = message_data.get('language') + + if not all([transcription_id, text, detected_language]): + logger.error("Invalid transcription_completed message", data=message_data) + return + + try: + # Define target languages (always English, German, Korean) + target_languages = ['en', 'de', 'ko'] + + # Remove source language from targets (no need to translate to itself) + if detected_language in target_languages: + target_languages = [lang for lang in target_languages if lang != detected_language] + + # Perform bulk translation + if target_languages: + bulk_result = await self.bulk_translate( + text, detected_language, target_languages, transcription_id + ) + + # Publish translation completed event + await self.publish_event('translations_completed', { + 'transcriptionId': transcription_id, + 'sourceLanguage': detected_language, + 'sourceText': text, + 'translations': { + lang: result.translated_text + for lang, result in bulk_result.translations.items() + }, + 'processingTimeMs': bulk_result.total_processing_time_ms + }) + + logger.info("Translations completed", + transcription_id=transcription_id, + target_languages=target_languages) + else: + # No translation needed, just publish the original + await self.publish_event('translations_completed', { + 'transcriptionId': transcription_id, + 'sourceLanguage': detected_language, + 'sourceText': text, + 'translations': {}, + 'processingTimeMs': 0 + }) + + except Exception as e: + logger.error("Error processing transcription event", + error=str(e), transcription_id=transcription_id) + + async def listen_for_events(self): + """Listen for transcription events""" + logger.info("Starting to listen for transcription events") + + pubsub = self.redis_client.pubsub() + await pubsub.subscribe('voice-translator-events') + + try: + async for message in pubsub.listen(): + if message['type'] == 'message': + try: + data = json.loads(message['data']) + + # Process transcription_completed events + if data.get('event') == 'transcription_completed': + await self.process_transcription_event(data['data']) + + except json.JSONDecodeError: + logger.error("Invalid JSON in message", data=message['data']) + except Exception as e: + logger.error("Error processing message", error=str(e)) + + except Exception as e: + logger.error("Error in message listener", error=str(e)) + finally: + await pubsub.unsubscribe('voice-translator-events') + +# Global service instance +translation_service = TranslationService() + +@app.on_event("startup") +async def startup_event(): + """Initialize the service on startup""" + await translation_service.initialize() + + # Start background task to listen for events + asyncio.create_task(translation_service.listen_for_events()) + +@app.on_event("shutdown") +async def shutdown_event(): + """Cleanup on shutdown""" + await translation_service.cleanup() + +@app.get("/health", response_model=HealthResponse) +async def health_check(): + """Health check endpoint""" + return HealthResponse( + status="healthy", + local_model_loaded=translation_service.nllb_model is not None, + device=DEVICE, + supported_languages=list(NLLB_LANG_MAP.keys()), + google_api_available=translation_service.deep_translator is not None + ) + +@app.post("/translate", response_model=TranslationResult) +async def translate_endpoint(request: TranslationRequest): + """Single translation endpoint""" + return await translation_service.translate_text( + request.text, + request.source_language, + request.target_language, + request.transcription_id + ) + +@app.post("/translate/bulk", response_model=BulkTranslationResult) +async def bulk_translate_endpoint(request: BulkTranslationRequest): + """Bulk translation endpoint""" + return await translation_service.bulk_translate( + request.text, + request.source_language, + request.target_languages, + request.transcription_id + ) + +@app.get("/languages") +async def list_languages(): + """List supported languages""" + return { + "supported_languages": LANGUAGE_INFO, + "local_model_languages": list(NLLB_LANG_MAP.keys()), + "primary_languages": ["en", "de", "ko"] + } + +if __name__ == "__main__": + uvicorn.run( + "api:app", + host="0.0.0.0", + port=8000, + log_level=LOG_LEVEL.lower(), + access_log=True + )