Create Translation service with local NLLB + Google Translate fallback

This commit is contained in:
2025-07-14 10:31:08 -05:00
parent c7550fdd81
commit dc5958a995

View File

@ -0,0 +1,566 @@
import asyncio
import hashlib
import json
import logging
import os
import time
from enum import Enum
from typing import Dict, List, Optional, Any

import httpx
import psycopg2
import redis.asyncio as redis
import structlog
import torch
import uvicorn
from cachetools import TTLCache
from deep_translator import GoogleTranslator as DeepGoogleTranslator
from fastapi import FastAPI, HTTPException
from googletrans import Translator as GoogleTranslator
from psycopg2.extras import DictCursor
from pydantic import BaseModel
from transformers import AutoTokenizer, M2M100ForConditionalGeneration, M2M100Tokenizer
# Structured logging: every record passes through this processor chain, in
# order, and is finally rendered as a single JSON object.
_LOG_PROCESSORS = [
    structlog.stdlib.filter_by_level,
    structlog.stdlib.add_log_level,
    structlog.stdlib.add_logger_name,
    structlog.stdlib.PositionalArgumentsFormatter(),
    structlog.processors.TimeStamper(fmt="iso"),
    structlog.processors.StackInfoRenderer(),
    structlog.processors.format_exc_info,
    structlog.processors.JSONRenderer(),
]

structlog.configure(
    processors=_LOG_PROCESSORS,
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
    wrapper_class=structlog.stdlib.BoundLogger,
    cache_logger_on_first_use=True,
)

# Module-wide logger used by the service class and the HTTP handlers.
logger = structlog.get_logger()
# Configuration (read once from the environment at import time)
POSTGRES_URL = os.getenv('POSTGRES_URL')
REDIS_URL = os.getenv('REDIS_URL', 'redis://redis:6379')
GOOGLE_TRANSLATE_API_KEY = os.getenv('GOOGLE_TRANSLATE_API_KEY')
# Directory the NLLB checkpoint is downloaded to; overridable for local runs.
MODEL_CACHE_DIR = os.getenv('MODEL_CACHE_DIR', '/app/models')
# Normalised to upper case so values such as "info" or " debug " are accepted.
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').strip().upper()
# Run the model on the GPU when torch can see one, otherwise on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set up logging level. An unknown LOG_LEVEL used to raise at import time
# (getattr without a default); fall back to INFO instead of crashing.
_resolved_log_level = getattr(logging, LOG_LEVEL, None)
logging.basicConfig(
    level=_resolved_log_level if isinstance(_resolved_log_level, int) else logging.INFO
)
# Language mappings
class Language(str, Enum):
    """Primary languages of the service, keyed by their two-letter code.

    Members are also ``str`` instances, so they compare equal to the raw
    code (``Language.ENGLISH == "en"``).
    """

    ENGLISH = "en"
    GERMAN = "de"
    KOREAN = "ko"
# Two-letter language code -> NLLB language tag (language + script).
# Only languages listed here can be handled by the local model.
NLLB_LANG_MAP = dict(
    en='eng_Latn',
    de='deu_Latn',
    ko='kor_Hang',
    es='spa_Latn',
    fr='fra_Latn',
    it='ita_Latn',
    pt='por_Latn',
    ru='rus_Cyrl',
    ja='jpn_Jpan',
    zh='zho_Hans',
    ar='arb_Arab',
)
# Display metadata per language code: human-readable name and flag emoji.
# Built from (code, name, flag) rows; row order is the resulting key order.
LANGUAGE_INFO = {
    code: {'name': display_name, 'flag': flag}
    for code, display_name, flag in (
        ('en', 'English', '🇺🇸'),
        ('de', 'German', '🇩🇪'),
        ('ko', 'Korean', '🇰🇷'),
        ('es', 'Spanish', '🇪🇸'),
        ('fr', 'French', '🇫🇷'),
        ('it', 'Italian', '🇮🇹'),
        ('pt', 'Portuguese', '🇵🇹'),
        ('ru', 'Russian', '🇷🇺'),
        ('ja', 'Japanese', '🇯🇵'),
        ('zh', 'Chinese', '🇨🇳'),
        ('ar', 'Arabic', '🇸🇦'),
    )
}
# Pydantic models
#
# Request body of POST /translate: one text, one target language.
class TranslationRequest(BaseModel):
    # Text to translate.
    text: str
    # Language code of `text`.
    source_language: str
    # Language code to translate into.
    target_language: str
    # When set, the result is also written to the database for this id.
    transcription_id: Optional[int] = None
# Request body of POST /translate/bulk: one text, several target languages.
class BulkTranslationRequest(BaseModel):
    # Text to translate.
    text: str
    # Language code of `text`.
    source_language: str
    # Language codes to translate into.
    target_languages: List[str]
    # When set, each result is also written to the database for this id.
    transcription_id: Optional[int] = None
# Outcome of translating one text into one target language.
class TranslationResult(BaseModel):
    # The translation (or an error placeholder produced by bulk translation).
    translated_text: str
    source_language: str
    target_language: str
    # Optional score; nothing in this file populates it.
    confidence_score: Optional[float] = None
    # Wall-clock duration of the translation in milliseconds.
    processing_time_ms: int
    # Which backend produced the text, e.g. "nllb_local", "google_translate",
    # "cache", "no_translation_needed" or "error".
    service_used: str
# Outcome of a bulk translation: one TranslationResult per target language.
class BulkTranslationResult(BaseModel):
    # The text that was translated.
    source_text: str
    source_language: str
    # Target language code -> result for that language.
    translations: Dict[str, TranslationResult]
    # Wall-clock duration of the whole bulk call in milliseconds.
    total_processing_time_ms: int
# Response body of GET /health.
class HealthResponse(BaseModel):
    status: str
    # True once the local NLLB model object exists.
    local_model_loaded: bool
    # Torch device selected at import time ("cuda" or "cpu").
    device: str
    # Language codes the local model mapping knows about.
    supported_languages: List[str]
    # True when the key-gated Google translator was created.
    google_api_available: bool
# ASGI application object; the route and lifecycle decorators below attach to it.
app = FastAPI(
    title="Translation Service",
    version="2.0.0",
)
class TranslationService:
    """Translation backend: local NLLB model first, Google Translate second.

    What this class does:

    * loads the ``facebook/nllb-200-distilled-600M`` checkpoint onto ``DEVICE``;
    * translates text, trying the local model and then Google Translate;
    * keeps recent results in an in-process ``TTLCache`` (1000 entries, 1 hour);
    * writes results to the ``translations`` table when a transcription id is
      supplied;
    * consumes ``transcription_completed`` events from the Redis channel
      ``voice-translator-events`` and publishes ``translations_completed``.

    NOTE(review): model loading/inference, the psycopg2 calls and synchronous
    Google clients run directly inside coroutines, so they block the event
    loop while they execute. Moving them to a worker thread would also need a
    lock around the tokenizer (``src_lang`` is shared mutable state); that is
    deliberately left out of this change.
    """

    def __init__(self):
        """Create an unconnected service; call ``initialize()`` before use."""
        self.nllb_model = None
        self.nllb_tokenizer = None
        self.google_translator = None
        self.deep_translator = None
        self.redis_client = None
        self.pg_connection = None
        self.model_loading = False
        # Translation cache (TTL: 1 hour)
        self.translation_cache = TTLCache(maxsize=1000, ttl=3600)

    async def initialize(self):
        """Open the Redis and PostgreSQL connections and load the local model.

        Raises:
            Exception: whatever the Redis ping, the PostgreSQL connect or the
                translator constructors raise; it is logged and re-raised so
                that application startup fails visibly. A failure to load the
                NLLB model is *not* propagated (see ``load_local_model``).
        """
        try:
            # Connect to Redis
            self.redis_client = redis.from_url(REDIS_URL)
            await self.redis_client.ping()

            # Connect to PostgreSQL
            self.pg_connection = psycopg2.connect(POSTGRES_URL)

            # Initialize Google Translators. The key-gated instance doubles as
            # the "google_api_available" flag reported by the health endpoint.
            self.google_translator = GoogleTranslator()
            self.deep_translator = (
                DeepGoogleTranslator(api_key=GOOGLE_TRANSLATE_API_KEY)
                if GOOGLE_TRANSLATE_API_KEY
                else None
            )

            # Load local NLLB model
            await self.load_local_model()

            logger.info("Translation service initialized successfully", device=DEVICE)
        except Exception as e:
            logger.error("Failed to initialize translation service", error=str(e))
            raise

    async def load_local_model(self):
        """Load the NLLB tokenizer and model into ``self``.

        Failures are logged and swallowed on purpose: without a local model
        the service keeps working through the Google fallback. Tokenizer and
        model are published together only after both loaded, so a partial
        failure never leaves one of them set without the other.
        """
        if self.model_loading:
            logger.warning("Model is already loading")
            return

        self.model_loading = True
        try:
            logger.info("Loading NLLB translation model", device=DEVICE)

            # Create model cache directory
            os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

            # NLLB-200 distilled model (600M parameters - good balance of speed/quality)
            model_name = "facebook/nllb-200-distilled-600M"

            # NLLB checkpoints ship their own tokenizer (language tags such as
            # "eng_Latn"), not the M2M100 one; AutoTokenizer resolves the right
            # class from the checkpoint, as in the Hugging Face NLLB docs.
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=MODEL_CACHE_DIR,
            )

            # torch_dtype already selects half precision on CUDA, so the extra
            # ``.half()`` call the original made afterwards was redundant.
            model = M2M100ForConditionalGeneration.from_pretrained(
                model_name,
                cache_dir=MODEL_CACHE_DIR,
                torch_dtype=torch.float16 if DEVICE == 'cuda' else torch.float32,
            ).to(DEVICE)

            self.nllb_tokenizer = tokenizer
            self.nllb_model = model
            logger.info("NLLB model loaded successfully")
        except Exception as e:
            logger.error("Failed to load NLLB model", error=str(e))
            # Continue without local model - will use web APIs
        finally:
            self.model_loading = False

    async def cleanup(self):
        """Close the Redis and PostgreSQL connections if they were opened."""
        if self.redis_client:
            await self.redis_client.close()
        if self.pg_connection:
            self.pg_connection.close()

    def get_cache_key(self, text: str, source_lang: str, target_lang: str) -> str:
        """Return the cache key for one (text, source, target) combination.

        The text is digested with SHA-256 rather than the built-in ``hash()``:
        the key is then identical across processes and restarts, and two
        different texts cannot realistically share a key (a ``hash()``
        collision would have served the wrong cached translation).
        """
        text_digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        return f"{source_lang}:{target_lang}:{text_digest}"

    async def translate_with_nllb(self, text: str, source_lang: str, target_lang: str) -> Optional[str]:
        """Translate ``text`` with the local NLLB model.

        Returns:
            The stripped translation, or ``None`` when the model is not
            loaded, either language is missing from ``NLLB_LANG_MAP``, or
            inference raised (the error is logged).
        """
        if self.nllb_model is None or self.nllb_tokenizer is None:
            return None

        try:
            # Map language codes to NLLB format
            source_nllb = NLLB_LANG_MAP.get(source_lang)
            target_nllb = NLLB_LANG_MAP.get(target_lang)
            if not source_nllb or not target_nllb:
                logger.warning("Language not supported by NLLB",
                               source=source_lang, target=target_lang)
                return None

            # The source language is tokenizer state, set before encoding.
            self.nllb_tokenizer.src_lang = source_nllb

            # Inputs longer than 512 tokens are truncated, not rejected.
            encoded = self.nllb_tokenizer(
                text, return_tensors="pt", padding=True, truncation=True, max_length=512
            )
            if DEVICE == 'cuda':
                encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

            # The target language is selected by forcing its tag as the first
            # generated token. Language tags are ordinary vocabulary tokens,
            # so convert_tokens_to_ids resolves them; the ``lang_code_to_id``
            # mapping used before is not available on current tokenizers.
            target_token_id = self.nllb_tokenizer.convert_tokens_to_ids(target_nllb)
            generated_tokens = self.nllb_model.generate(
                **encoded,
                forced_bos_token_id=target_token_id,
                max_length=512,
                num_beams=4,
                length_penalty=1.0,
                early_stopping=True,
            )

            # Decode translation
            translated_text = self.nllb_tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]
            return translated_text.strip()
        except Exception as e:
            logger.error("Error in NLLB translation",
                         error=str(e), source=source_lang, target=target_lang)
            return None

    async def translate_with_google(self, text: str, source_lang: str, target_lang: str) -> Optional[str]:
        """Translate ``text`` through the Google Translate clients.

        The deep_translator client is tried first when it was enabled during
        ``initialize()``; the googletrans client is the fallback.

        Returns:
            The translated text, or ``None`` when every attempt raised (the
            error is logged).
        """
        try:
            # Try premium API first if available
            if self.deep_translator is not None:
                try:
                    # deep_translator fixes source/target on the instance;
                    # its translate() does not use per-call ``source=`` /
                    # ``target=`` keywords, so the shared instance always
                    # translated with its defaults. Build one per language pair.
                    # NOTE(review): this GoogleTranslator does not appear to
                    # consume an API key — confirm which client is meant to be
                    # the "premium" one.
                    pair_translator = DeepGoogleTranslator(source=source_lang, target=target_lang)
                    return pair_translator.translate(text)
                except Exception as e:
                    logger.warning("Premium Google API failed, trying free version", error=str(e))

            # Fallback to free Google Translate
            result = self.google_translator.translate(text, dest=target_lang, src=source_lang)
            # Some googletrans releases make translate() a coroutine function;
            # accept both so ``.text`` is read from the result, not a coroutine.
            if asyncio.iscoroutine(result):
                result = await result
            return result.text
        except Exception as e:
            logger.error("Error in Google translation",
                         error=str(e), source=source_lang, target=target_lang)
            return None

    async def translate_text(self, text: str, source_lang: str, target_lang: str,
                             transcription_id: Optional[int] = None) -> TranslationResult:
        """Translate ``text`` using cache -> local NLLB -> Google Translate.

        Args:
            text: Text to translate.
            source_lang: Language code of ``text``.
            target_lang: Language code to translate into.
            transcription_id: When not ``None``, the result (including one
                served from the cache) is written to the database.

        Returns:
            A ``TranslationResult`` whose ``service_used`` is one of
            ``"no_translation_needed"``, ``"cache"``, ``"nllb_local"`` or
            ``"google_translate"``.

        Raises:
            HTTPException: status 500 when neither backend produced a text.
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        started = time.perf_counter()

        # Blank text or identical languages: return the input unchanged, the
        # same way bulk_translate treats source == target. Blank input
        # previously ran through every backend and ended in a 500.
        if source_lang == target_lang or not text.strip():
            return TranslationResult(
                translated_text=text,
                source_language=source_lang,
                target_language=target_lang,
                processing_time_ms=0,
                service_used="no_translation_needed",
            )

        # Check cache first. A single get() avoids the window between an
        # ``in`` test and the lookup in which a TTL entry can expire.
        cache_key = self.get_cache_key(text, source_lang, target_lang)
        cached_text = self.translation_cache.get(cache_key)

        if cached_text is not None:
            logger.debug("Using cached translation", source=source_lang, target=target_lang)
            result = TranslationResult(
                translated_text=cached_text,
                source_language=source_lang,
                target_language=target_lang,
                processing_time_ms=int((time.perf_counter() - started) * 1000),
                service_used="cache",
            )
        else:
            translated_text = None
            service_used = "none"

            # Try local NLLB model first
            if self.nllb_model is not None:
                translated_text = await self.translate_with_nllb(text, source_lang, target_lang)
                if translated_text:
                    service_used = "nllb_local"

            # Fallback to Google Translate if local model failed
            if not translated_text:
                translated_text = await self.translate_with_google(text, source_lang, target_lang)
                if translated_text:
                    service_used = "google_translate"

            if not translated_text:
                raise HTTPException(status_code=500, detail="All translation services failed")

            processing_time_ms = int((time.perf_counter() - started) * 1000)

            # Cache the result
            self.translation_cache[cache_key] = translated_text

            result = TranslationResult(
                translated_text=translated_text,
                source_language=source_lang,
                target_language=target_lang,
                processing_time_ms=processing_time_ms,
                service_used=service_used,
            )
            logger.info("Translation completed",
                        source=source_lang, target=target_lang,
                        service=service_used, time_ms=processing_time_ms)

        # Persist for the transcription, cache hit or not: previously a cache
        # hit returned early and the transcription got no database row.
        if transcription_id is not None:
            await self.save_translation_record(transcription_id, result)

        return result

    async def bulk_translate(self, text: str, source_lang: str, target_languages: List[str],
                             transcription_id: Optional[int] = None) -> BulkTranslationResult:
        """Translate ``text`` into every language in ``target_languages``.

        A failure for one language does not abort the others: that language
        gets a placeholder result with ``service_used="error"``. Duplicate
        target codes are translated (and persisted) only once.

        Returns:
            A ``BulkTranslationResult`` mapping each distinct target language
            to its ``TranslationResult``.
        """
        started = time.perf_counter()
        translations = {}

        # dict.fromkeys drops duplicates while keeping the requested order.
        for target_lang in dict.fromkeys(target_languages):
            if source_lang == target_lang:
                # Skip translation if source equals target
                translations[target_lang] = TranslationResult(
                    translated_text=text,
                    source_language=source_lang,
                    target_language=target_lang,
                    processing_time_ms=0,
                    service_used="no_translation_needed",
                )
                continue

            try:
                translations[target_lang] = await self.translate_text(
                    text, source_lang, target_lang, transcription_id
                )
            except Exception as e:
                logger.error("Failed to translate to language",
                             target=target_lang, error=str(e))
                # Create error result
                translations[target_lang] = TranslationResult(
                    translated_text=f"[Translation failed: {str(e)}]",
                    source_language=source_lang,
                    target_language=target_lang,
                    processing_time_ms=0,
                    service_used="error",
                )

        total_time_ms = int((time.perf_counter() - started) * 1000)
        return BulkTranslationResult(
            source_text=text,
            source_language=source_lang,
            translations=translations,
            total_processing_time_ms=total_time_ms,
        )

    async def save_translation_record(self, transcription_id: int, result: TranslationResult):
        """Insert one row into ``translations`` for ``result``.

        Best effort: database errors are logged and never propagated, so a
        storage problem cannot turn a successful translation into a failed
        request.
        """
        try:
            with self.pg_connection.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO translations
                    (transcription_id, target_language, translated_text, translation_service,
                    processing_time_ms, status)
                    VALUES (%s, %s, %s, %s, %s, %s)
                    RETURNING id
                """, (
                    transcription_id,
                    result.target_language,
                    result.translated_text,
                    result.service_used,
                    result.processing_time_ms,
                    'completed'
                ))
                translation_id = cursor.fetchone()[0]
            self.pg_connection.commit()
            logger.debug("Saved translation record",
                         translation_id=translation_id, transcription_id=transcription_id)
        except Exception as e:
            logger.error("Failed to save translation record",
                         error=str(e), transcription_id=transcription_id)
            # rollback() itself raises when the connection is gone; guard it so
            # the failure stays inside this best-effort method.
            try:
                self.pg_connection.rollback()
            except Exception as rollback_error:
                logger.error("Failed to roll back translation record",
                             error=str(rollback_error), transcription_id=transcription_id)

    async def publish_event(self, event: str, data: Dict[str, Any]):
        """Publish ``event`` with ``data`` on ``voice-translator-events``.

        The message is a JSON object with the keys ``event``, ``data``,
        ``timestamp`` and ``service``. Publishing errors are logged, not
        raised.
        """
        try:
            message = {
                'event': event,
                'data': data,
                'timestamp': time.time(),
                'service': 'translator'
            }
            await self.redis_client.publish('voice-translator-events', json.dumps(message))
            # structlog's first positional parameter is itself named ``event``;
            # passing ``event=`` as a keyword raised TypeError on every call
            # (here and in the handler below), so it is logged as event_name.
            logger.debug("Published event", event_name=event, data=data)
        except Exception as e:
            logger.error("Failed to publish event", error=str(e), event_name=event)

    async def process_transcription_event(self, message_data: Dict[str, Any]):
        """Translate a finished transcription and announce the result.

        ``message_data`` must provide ``transcriptionId``, ``text`` and
        ``language``; a message with any of them missing or falsy is logged
        and ignored. The text is translated into English, German and Korean
        (minus the detected language) and a ``translations_completed`` event
        is published. Errors are logged and never raised, so one bad message
        cannot stop the listener.
        """
        transcription_id = message_data.get('transcriptionId')
        text = message_data.get('text')
        detected_language = message_data.get('language')

        if not all([transcription_id, text, detected_language]):
            logger.error("Invalid transcription_completed message", data=message_data)
            return

        try:
            # Always English, German and Korean, without the source language
            # itself (no need to translate a text into its own language).
            target_languages = [
                lang for lang in ('en', 'de', 'ko') if lang != detected_language
            ]

            if target_languages:
                # Perform bulk translation
                bulk_result = await self.bulk_translate(
                    text, detected_language, target_languages, transcription_id
                )

                # Publish translation completed event
                await self.publish_event('translations_completed', {
                    'transcriptionId': transcription_id,
                    'sourceLanguage': detected_language,
                    'sourceText': text,
                    'translations': {
                        lang: result.translated_text
                        for lang, result in bulk_result.translations.items()
                    },
                    'processingTimeMs': bulk_result.total_processing_time_ms
                })
                logger.info("Translations completed",
                            transcription_id=transcription_id,
                            target_languages=target_languages)
            else:
                # No translation needed, just publish the original
                await self.publish_event('translations_completed', {
                    'transcriptionId': transcription_id,
                    'sourceLanguage': detected_language,
                    'sourceText': text,
                    'translations': {},
                    'processingTimeMs': 0
                })
        except Exception as e:
            logger.error("Error processing transcription event",
                         error=str(e), transcription_id=transcription_id)

    async def listen_for_events(self):
        """Consume ``voice-translator-events`` until the subscription ends.

        Only messages whose ``event`` is ``transcription_completed`` are
        processed; this service publishes on the same channel, so its own
        events arrive here too and are skipped. A malformed message is logged
        and skipped; an error in the subscription itself is logged and ends
        the loop (there is no reconnect).
        """
        logger.info("Starting to listen for transcription events")
        pubsub = self.redis_client.pubsub()
        await pubsub.subscribe('voice-translator-events')

        try:
            async for message in pubsub.listen():
                # Subscribe confirmations and similar control frames are skipped.
                if message['type'] != 'message':
                    continue
                try:
                    data = json.loads(message['data'])
                    # Process transcription_completed events
                    if data.get('event') == 'transcription_completed':
                        await self.process_transcription_event(data['data'])
                except json.JSONDecodeError:
                    logger.error("Invalid JSON in message", data=message['data'])
                except Exception as e:
                    logger.error("Error processing message", error=str(e))
        except Exception as e:
            logger.error("Error in message listener", error=str(e))
        finally:
            await pubsub.unsubscribe('voice-translator-events')
# Single shared service object used by every route and lifecycle handler.
# Constructing it opens no connections; that happens in initialize().
translation_service = TranslationService()
@app.on_event("startup")
async def startup_event():
    """Initialize the service on startup"""
    await translation_service.initialize()
    # Start background task to listen for events. The event loop keeps only a
    # weak reference to tasks, so the task object is stored on the app: a
    # discarded create_task() result can be garbage-collected mid-execution.
    app.state.event_listener_task = asyncio.create_task(
        translation_service.listen_for_events()
    )


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown"""
    # Stop the listener before closing the connections it is using, and wait
    # for it so its ``finally`` block (the unsubscribe) has run.
    listener_task = getattr(app.state, "event_listener_task", None)
    if listener_task is not None and not listener_task.done():
        listener_task.cancel()
        try:
            await listener_task
        except asyncio.CancelledError:
            # Expected result of the cancel() above.
            pass
    await translation_service.cleanup()
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    # Status is always "healthy"; the two flags show which backends exist.
    model_ready = translation_service.nllb_model is not None
    google_ready = translation_service.deep_translator is not None
    return HealthResponse(
        status="healthy",
        local_model_loaded=model_ready,
        device=DEVICE,
        supported_languages=[*NLLB_LANG_MAP],
        google_api_available=google_ready,
    )
@app.post("/translate", response_model=TranslationResult)
async def translate_endpoint(request: TranslationRequest):
    """Single translation endpoint"""
    # Thin wrapper: unpack the request body and delegate to the service.
    result = await translation_service.translate_text(
        text=request.text,
        source_lang=request.source_language,
        target_lang=request.target_language,
        transcription_id=request.transcription_id,
    )
    return result
@app.post("/translate/bulk", response_model=BulkTranslationResult)
async def bulk_translate_endpoint(request: BulkTranslationRequest):
    """Bulk translation endpoint"""
    # Thin wrapper: unpack the request body and delegate to the service.
    bulk_result = await translation_service.bulk_translate(
        text=request.text,
        source_lang=request.source_language,
        target_languages=request.target_languages,
        transcription_id=request.transcription_id,
    )
    return bulk_result
@app.get("/languages")
async def list_languages():
    """List supported languages"""
    # Codes the local model mapping covers, in declaration order.
    local_codes = [code for code in NLLB_LANG_MAP]
    payload = {}
    payload["supported_languages"] = LANGUAGE_INFO
    payload["local_model_languages"] = local_codes
    payload["primary_languages"] = ["en", "de", "ko"]
    return payload
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    # The "api:app" import string is resolved by uvicorn at startup.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "log_level": LOG_LEVEL.lower(),
        "access_log": True,
    }
    uvicorn.run("api:app", **server_options)