Create Translation service with local NLLB + Google Translate fallback
This commit is contained in:
566
services/translator/src/api.py
Normal file
566
services/translator/src/api.py
Normal file
@ -0,0 +1,566 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from enum import Enum
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import redis.asyncio as redis
|
||||
import psycopg2
|
||||
from psycopg2.extras import DictCursor
|
||||
import structlog
|
||||
import torch
|
||||
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
|
||||
from googletrans import Translator as GoogleTranslator
|
||||
from deep_translator import GoogleTranslator as DeepGoogleTranslator
|
||||
import httpx
|
||||
from cachetools import TTLCache
|
||||
|
||||
# Configure structured logging.
# The processor chain is applied in list order; the final JSONRenderer turns
# each event dict into a JSON string.
_STRUCTLOG_PROCESSORS = [
    structlog.stdlib.filter_by_level,
    structlog.stdlib.add_log_level,
    structlog.stdlib.add_logger_name,
    structlog.stdlib.PositionalArgumentsFormatter(),
    structlog.processors.TimeStamper(fmt="iso"),
    structlog.processors.StackInfoRenderer(),
    structlog.processors.format_exc_info,
    structlog.processors.JSONRenderer(),
]

structlog.configure(
    processors=_STRUCTLOG_PROCESSORS,
    wrapper_class=structlog.stdlib.BoundLogger,
    logger_factory=structlog.stdlib.LoggerFactory(),
    context_class=dict,
    cache_logger_on_first_use=True,
)

# Module-wide structured logger used by every function in this file.
logger = structlog.get_logger()
|
||||
|
||||
# Configuration, read once from the environment at import time.
POSTGRES_URL = os.getenv('POSTGRES_URL')  # no default: None when unset
REDIS_URL = os.getenv('REDIS_URL', 'redis://redis:6379')
GOOGLE_TRANSLATE_API_KEY = os.getenv('GOOGLE_TRANSLATE_API_KEY')
MODEL_CACHE_DIR = '/app/models'
# Upper-cased so that values such as "debug" resolve to a logging level name.
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set up logging level.
# An unknown LOG_LEVEL used to raise AttributeError while importing the module;
# fall back to INFO instead so a typo in the environment cannot stop startup.
_resolved_log_level = getattr(logging, LOG_LEVEL, None)
if not isinstance(_resolved_log_level, int):
    LOG_LEVEL = 'INFO'
    _resolved_log_level = logging.INFO
logging.basicConfig(level=_resolved_log_level)
|
||||
|
||||
# Language mappings
|
||||
class Language(str, Enum):
    """Two-letter codes of the three languages this service treats as primary.

    Members are also ``str`` instances, so they compare equal to the plain
    code strings used elsewhere in this module.
    """

    ENGLISH = "en"
    GERMAN = "de"
    KOREAN = "ko"
|
||||
|
||||
# NLLB language code mappings
|
||||
# Maps the two-letter codes accepted by the API to the language tags passed to
# the NLLB tokenizer (see translate_with_nllb).
NLLB_LANG_MAP = dict(
    en='eng_Latn',
    de='deu_Latn',
    ko='kor_Hang',
    es='spa_Latn',
    fr='fra_Latn',
    it='ita_Latn',
    pt='por_Latn',
    ru='rus_Cyrl',
    ja='jpn_Jpan',
    zh='zho_Hans',
    ar='arb_Arab',
)
|
||||
|
||||
# Language display names and flags
|
||||
# Display metadata (English name and flag emoji) per language code; returned
# verbatim by the /languages endpoint.
LANGUAGE_INFO = {
    code: {'name': name, 'flag': flag}
    for code, name, flag in (
        ('en', 'English', '🇺🇸'),
        ('de', 'German', '🇩🇪'),
        ('ko', 'Korean', '🇰🇷'),
        ('es', 'Spanish', '🇪🇸'),
        ('fr', 'French', '🇫🇷'),
        ('it', 'Italian', '🇮🇹'),
        ('pt', 'Portuguese', '🇵🇹'),
        ('ru', 'Russian', '🇷🇺'),
        ('ja', 'Japanese', '🇯🇵'),
        ('zh', 'Chinese', '🇨🇳'),
        ('ar', 'Arabic', '🇸🇦'),
    )
}
|
||||
|
||||
# Pydantic models
|
||||
# Request body of POST /translate.
# (Documented with comments rather than a docstring so the generated schema
# description stays unchanged.)
class TranslationRequest(BaseModel):
    # Text to translate
    text: str
    # Language code of `text`
    source_language: str
    # Language code to translate into
    target_language: str
    # Optional id forwarded to translate_text, which stores the result under it
    transcription_id: Optional[int] = None
|
||||
|
||||
# Request body of POST /translate/bulk: one text, several target languages.
# (Documented with comments rather than a docstring so the generated schema
# description stays unchanged.)
class BulkTranslationRequest(BaseModel):
    # Text to translate
    text: str
    # Language code of `text`
    source_language: str
    # Language codes to translate into; each one yields an entry in the result
    target_languages: List[str]
    # Optional id forwarded to bulk_translate
    transcription_id: Optional[int] = None
|
||||
|
||||
# Outcome of translating one text into one target language.
# (Documented with comments rather than a docstring so the generated schema
# description stays unchanged.)
class TranslationResult(BaseModel):
    # Translated text
    translated_text: str
    # Language code of the input text
    source_language: str
    # Language code of `translated_text`
    target_language: str
    # Not populated by any code visible in this module; defaults to None
    confidence_score: Optional[float] = None
    # Elapsed time for this translation, in milliseconds
    processing_time_ms: int
    # Backend label, e.g. "nllb_local", "google_translate", "cache", "error"
    service_used: str
|
||||
|
||||
# Aggregate returned by bulk_translate.
# (Documented with comments rather than a docstring so the generated schema
# description stays unchanged.)
class BulkTranslationResult(BaseModel):
    # The text that was translated
    source_text: str
    # Language code of `source_text`
    source_language: str
    # One TranslationResult per requested target language, keyed by its code
    translations: Dict[str, TranslationResult]
    # Elapsed time for the whole batch, in milliseconds
    total_processing_time_ms: int
|
||||
|
||||
# Payload of GET /health.
# (Documented with comments rather than a docstring so the generated schema
# description stays unchanged.)
class HealthResponse(BaseModel):
    # The health endpoint always reports "healthy"
    status: str
    # True once the local NLLB model object exists
    local_model_loaded: bool
    # 'cuda' or 'cpu', as detected at import time
    device: str
    # Language codes known to the NLLB mapping
    supported_languages: List[str]
    # True when the deep_translator client was created (API key configured)
    google_api_available: bool
|
||||
|
||||
# ASGI application object; the routes and lifecycle hooks in this module are
# registered on it.
app = FastAPI(
    version="2.0.0",
    title="Translation Service",
)
|
||||
|
||||
class TranslationService:
|
||||
def __init__(self):
    """Create an unconnected service; initialize() opens connections and loads the model."""
    # Local NLLB model and tokenizer, set by load_local_model()
    self.nllb_model = self.nllb_tokenizer = None
    # Google clients, created in initialize()
    self.google_translator = self.deep_translator = None
    # Redis and PostgreSQL handles, opened in initialize()
    self.redis_client = self.pg_connection = None
    # Flag that stops load_local_model() from running twice at once
    self.model_loading = False

    # In-process translation cache: at most 1000 entries, each kept for one hour
    self.translation_cache = TTLCache(ttl=60 * 60, maxsize=1000)
|
||||
|
||||
async def initialize(self):
    """Open the Redis and PostgreSQL connections, create the Google clients and load NLLB.

    Raises:
        Exception: any error from the steps below is logged and re-raised
            unchanged. load_local_model() handles its own errors, so a model
            that fails to load does not surface here.
    """
    try:
        # Redis: build the client, then verify the server answers
        redis_conn = redis.from_url(REDIS_URL)
        self.redis_client = redis_conn
        await redis_conn.ping()

        # PostgreSQL: a single connection shared by the service
        self.pg_connection = psycopg2.connect(POSTGRES_URL)

        # Google clients: the key-based one exists only when a key is configured
        self.google_translator = GoogleTranslator()
        if GOOGLE_TRANSLATE_API_KEY:
            self.deep_translator = DeepGoogleTranslator(api_key=GOOGLE_TRANSLATE_API_KEY)
        else:
            self.deep_translator = None

        # Local model last
        await self.load_local_model()
    except Exception as exc:
        logger.error("Failed to initialize translation service", error=str(exc))
        raise
    else:
        logger.info("Translation service initialized successfully", device=DEVICE)
|
||||
|
||||
async def load_local_model(self):
    """Load the local NLLB-200 (distilled, 600M) model and its tokenizer.

    Returns immediately if a load is already in progress. Any failure is
    logged and swallowed so the service can keep running on the Google
    fallback. The model and tokenizer attributes are assigned only after both
    loaded, so a partial failure never leaves just one of them set.
    """
    if self.model_loading:
        logger.warning("Model is already loading")
        return

    self.model_loading = True
    try:
        # Local import: the file-level imports only provide the M2M100 classes.
        # NLLB checkpoints ship their own tokenizer with language tags such as
        # 'eng_Latn'; AutoTokenizer picks the class matching the checkpoint.
        from transformers import AutoTokenizer

        logger.info("Loading NLLB translation model", device=DEVICE)

        # Create model cache directory
        os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

        # NLLB-200 distilled model (600M parameters - good balance of speed/quality)
        model_name = "facebook/nllb-200-distilled-600M"

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=MODEL_CACHE_DIR,
        )

        # Half precision on GPU for faster inference; torch_dtype already loads
        # the weights in that dtype, so no separate .half() call is needed.
        model = M2M100ForConditionalGeneration.from_pretrained(
            model_name,
            cache_dir=MODEL_CACHE_DIR,
            torch_dtype=torch.float16 if DEVICE == 'cuda' else torch.float32,
        ).to(DEVICE)

        self.nllb_tokenizer = tokenizer
        self.nllb_model = model

        logger.info("NLLB model loaded successfully")

    except Exception as e:
        # Continue without local model - will use web APIs
        logger.error("Failed to load NLLB model", error=str(e))
    finally:
        self.model_loading = False
|
||||
|
||||
async def cleanup(self):
    """Close the Redis client and the PostgreSQL connection, if they exist.

    The PostgreSQL connection is closed even when closing Redis raises; the
    Redis error still propagates to the caller afterwards.
    """
    try:
        if self.redis_client:
            await self.redis_client.close()
    finally:
        if self.pg_connection:
            self.pg_connection.close()
|
||||
|
||||
def get_cache_key(self, text: str, source_lang: str, target_lang: str) -> str:
    """Build the cache key for one (text, source language, target language) triple.

    The text is represented by its SHA-256 hex digest. The built-in hash() is
    salted per process and only 64 bits wide, so two different texts could end
    up sharing a cached translation.

    Returns:
        "<source_lang>:<target_lang>:<sha256 hex digest of text>"
    """
    import hashlib  # local import keeps this block self-contained

    # surrogatepass: lone surrogates in the input must not make key creation fail
    digest = hashlib.sha256(text.encode("utf-8", "surrogatepass")).hexdigest()
    return f"{source_lang}:{target_lang}:{digest}"
|
||||
|
||||
async def translate_with_nllb(self, text: str, source_lang: str, target_lang: str) -> Optional[str]:
    """Translate `text` with the local NLLB model.

    Args:
        text: text to translate; input beyond 512 tokens is truncated.
        source_lang: two-letter source code (a key of NLLB_LANG_MAP).
        target_lang: two-letter target code (a key of NLLB_LANG_MAP).

    Returns:
        The stripped translation, or None when the model is not loaded, a
        language is not mapped, or inference fails (the error is logged).
    """
    # Explicit None checks: the truthiness of a tokenizer depends on its length
    if self.nllb_model is None or self.nllb_tokenizer is None:
        return None

    try:
        # Map language codes to NLLB format
        source_nllb = NLLB_LANG_MAP.get(source_lang)
        target_nllb = NLLB_LANG_MAP.get(target_lang)

        if not source_nllb or not target_nllb:
            logger.warning("Language not supported by NLLB",
                           source=source_lang, target=target_lang)
            return None

        tokenizer = self.nllb_tokenizer

        # Resolve the target language token through the generic vocabulary
        # lookup; the lang_code_to_id mapping is not available on every
        # tokenizer version.
        target_token_id = tokenizer.convert_tokens_to_ids(target_nllb)
        if target_token_id is None or target_token_id == getattr(tokenizer, "unk_token_id", None):
            logger.warning("Language not supported by NLLB",
                           source=source_lang, target=target_lang)
            return None

        # The source language is tokenizer state, so set it before encoding
        tokenizer.src_lang = source_nllb

        encoded = tokenizer(text, return_tensors="pt", padding=True,
                            truncation=True, max_length=512)
        # Moving to the device is a no-op when the tensors are already there
        encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

        # Beam search, forced to start with the target language token
        generated_tokens = self.nllb_model.generate(
            **encoded,
            forced_bos_token_id=target_token_id,
            max_length=512,
            num_beams=4,
            length_penalty=1.0,
            early_stopping=True,
        )

        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]

        return translated_text.strip()

    except Exception as e:
        logger.error("Error in NLLB translation",
                     error=str(e), source=source_lang, target=target_lang)
        return None
|
||||
|
||||
async def translate_with_google(self, text: str, source_lang: str, target_lang: str) -> Optional[str]:
    """Translate `text` through Google, preferring the key-configured client.

    The deep_translator client is tried first when it was configured; if it
    fails, or is absent, the googletrans client is used. Both calls run in the
    default executor so their blocking HTTP requests do not stall the event
    loop.

    Returns:
        The translated text, or None if every attempt failed (errors are logged).
    """
    loop = asyncio.get_running_loop()

    try:
        # Try premium API first if available
        if self.deep_translator:
            try:
                # deep_translator takes the language pair in the constructor;
                # keyword arguments to translate() are ignored, which left the
                # target at its default. Build a translator for this pair.
                # NOTE(review): confirm the api_key argument has any effect on
                # deep_translator's GoogleTranslator.
                pair_translator = DeepGoogleTranslator(
                    source=source_lang,
                    target=target_lang,
                    api_key=GOOGLE_TRANSLATE_API_KEY,
                )
                return await loop.run_in_executor(None, pair_translator.translate, text)
            except Exception as e:
                logger.warning("Premium Google API failed, trying free version", error=str(e))

        # Fallback to free Google Translate
        result = await loop.run_in_executor(
            None,
            lambda: self.google_translator.translate(text, dest=target_lang, src=source_lang),
        )
        # Newer googletrans releases return an awaitable from translate()
        if asyncio.iscoroutine(result) or asyncio.isfuture(result):
            result = await result
        return result.text

    except Exception as e:
        logger.error("Error in Google translation",
                     error=str(e), source=source_lang, target=target_lang)
        return None
|
||||
|
||||
async def translate_text(self, text: str, source_lang: str, target_lang: str,
                         transcription_id: Optional[int] = None) -> TranslationResult:
    """Translate `text`, trying the cache, then local NLLB, then Google.

    Args:
        text: text to translate.
        source_lang: language code of `text`.
        target_lang: language code to translate into.
        transcription_id: when given (including 0), the result is stored via
            save_translation_record, also when it came from the cache.

    Returns:
        A TranslationResult whose service_used is "cache", "nllb_local" or
        "google_translate".

    Raises:
        HTTPException: status 500 when no backend produced a non-empty
            translation.
    """
    start_time = time.perf_counter()

    def elapsed_ms() -> int:
        # Monotonic clock: unaffected by wall-clock adjustments
        return int((time.perf_counter() - start_time) * 1000)

    # Check cache first. A single .get() avoids the window in which an entry
    # can expire between a membership test and the lookup.
    cache_key = self.get_cache_key(text, source_lang, target_lang)
    cached_result = self.translation_cache.get(cache_key)
    if cached_result is not None:
        logger.debug("Using cached translation", source=source_lang, target=target_lang)
        result = TranslationResult(
            translated_text=cached_result,
            source_language=source_lang,
            target_language=target_lang,
            processing_time_ms=elapsed_ms(),
            service_used="cache",
        )
        # A cache hit must still be recorded for this transcription; otherwise
        # repeated texts would have no stored translation.
        if transcription_id is not None:
            await self.save_translation_record(transcription_id, result)
        return result

    translated_text = None
    service_used = "none"

    # Try local NLLB model first
    if self.nllb_model:
        translated_text = await self.translate_with_nllb(text, source_lang, target_lang)
        if translated_text:
            service_used = "nllb_local"

    # Fallback to Google Translate if local model failed
    if not translated_text:
        translated_text = await self.translate_with_google(text, source_lang, target_lang)
        if translated_text:
            service_used = "google_translate"

    if not translated_text:
        raise HTTPException(status_code=500, detail="All translation services failed")

    processing_time_ms = elapsed_ms()

    # Cache the result
    self.translation_cache[cache_key] = translated_text

    result = TranslationResult(
        translated_text=translated_text,
        source_language=source_lang,
        target_language=target_lang,
        processing_time_ms=processing_time_ms,
        service_used=service_used,
    )

    # Update database if transcription_id provided ("is not None" so id 0 counts)
    if transcription_id is not None:
        await self.save_translation_record(transcription_id, result)

    logger.info("Translation completed",
                source=source_lang, target=target_lang,
                service=service_used, time_ms=processing_time_ms)

    return result
|
||||
|
||||
async def bulk_translate(self, text: str, source_lang: str, target_languages: List[str],
                         transcription_id: Optional[int] = None) -> BulkTranslationResult:
    """Translate `text` into each language of `target_languages`, one after another.

    A target equal to the source is answered with the untouched text. A target
    whose translation raises gets a placeholder entry with service "error", so
    one failure never aborts the batch.

    Returns:
        BulkTranslationResult keyed by target language code, with the total
        elapsed time in milliseconds.
    """
    started_at = time.time()

    def zero_time_result(target: str, body: str, service: str) -> TranslationResult:
        # Entries produced without running a translation report 0 ms
        return TranslationResult(
            translated_text=body,
            source_language=source_lang,
            target_language=target,
            processing_time_ms=0,
            service_used=service,
        )

    per_language = {}
    for target in target_languages:
        if target == source_lang:
            # Nothing to translate
            per_language[target] = zero_time_result(target, text, "no_translation_needed")
            continue

        try:
            per_language[target] = await self.translate_text(
                text, source_lang, target, transcription_id
            )
        except Exception as e:
            logger.error("Failed to translate to language",
                         target=target, error=str(e))
            per_language[target] = zero_time_result(
                target, f"[Translation failed: {str(e)}]", "error"
            )

    return BulkTranslationResult(
        source_text=text,
        source_language=source_lang,
        translations=per_language,
        total_processing_time_ms=int((time.time() - started_at) * 1000),
    )
|
||||
|
||||
async def save_translation_record(self, transcription_id: int, result: TranslationResult):
    """Insert one row into the `translations` table for `result`.

    This is best effort: failures are logged and never raised, so a database
    problem cannot fail a translation that already succeeded. That includes a
    failing rollback, e.g. when the connection is gone or was never opened.

    Args:
        transcription_id: id stored in the row's transcription_id column.
        result: translation whose target language, text, service and timing
            are stored; the status column is always 'completed'.
    """
    try:
        with self.pg_connection.cursor() as cursor:
            cursor.execute("""
                INSERT INTO translations
                (transcription_id, target_language, translated_text, translation_service,
                 processing_time_ms, status)
                VALUES (%s, %s, %s, %s, %s, %s)
                RETURNING id
            """, (
                transcription_id,
                result.target_language,
                result.translated_text,
                result.service_used,
                result.processing_time_ms,
                'completed',
            ))

            translation_id = cursor.fetchone()[0]
            self.pg_connection.commit()

            logger.debug("Saved translation record",
                         translation_id=translation_id, transcription_id=transcription_id)

    except Exception as e:
        logger.error("Failed to save translation record",
                     error=str(e), transcription_id=transcription_id)
        # The rollback can fail too; it must not escape this handler
        try:
            self.pg_connection.rollback()
        except Exception as rollback_error:
            logger.error("Failed to roll back translation record",
                         error=str(rollback_error), transcription_id=transcription_id)
|
||||
|
||||
async def publish_event(self, event: str, data: Dict[str, Any]):
    """Publish a JSON event on the 'voice-translator-events' Redis channel.

    The message carries the event name, the payload, a Unix timestamp and the
    service name 'translator'. Failures are logged and never raised.

    Args:
        event: event name placed in the message's 'event' field.
        data: JSON-serialisable payload placed in the 'data' field.
    """
    try:
        message = {
            'event': event,
            'data': data,
            'timestamp': time.time(),
            'service': 'translator',
        }
        await self.redis_client.publish('voice-translator-events', json.dumps(message))
    except Exception as e:
        # structlog's first positional parameter is itself called `event`, so
        # the event name is logged as `event_name`; passing event= as a keyword
        # raises TypeError.
        logger.error("Failed to publish event", error=str(e), event_name=event)
        return

    logger.debug("Published event", event_name=event, data=data)
|
||||
|
||||
async def process_transcription_event(self, message_data: Dict[str, Any]):
    """Translate a finished transcription and announce the result.

    Reads 'transcriptionId', 'text' and 'language' from `message_data`,
    translates the text into English, German and Korean (skipping the detected
    language) and publishes a 'translations_completed' event. A message
    missing any of the three fields is logged and dropped; any other error is
    logged and swallowed.
    """
    transcription_id = message_data.get('transcriptionId')
    text = message_data.get('text')
    detected_language = message_data.get('language')

    if not (transcription_id and text and detected_language):
        logger.error("Invalid transcription_completed message", data=message_data)
        return

    try:
        # Fixed targets, minus the language the text is already in
        target_languages = [
            lang for lang in ['en', 'de', 'ko'] if lang != detected_language
        ]

        if not target_languages:
            # No translation needed, just publish the original
            await self.publish_event('translations_completed', {
                'transcriptionId': transcription_id,
                'sourceLanguage': detected_language,
                'sourceText': text,
                'translations': {},
                'processingTimeMs': 0
            })
            return

        bulk_result = await self.bulk_translate(
            text, detected_language, target_languages, transcription_id
        )

        # Only the translated strings are published, keyed by language code
        translated_by_language = {
            lang: result.translated_text
            for lang, result in bulk_result.translations.items()
        }
        await self.publish_event('translations_completed', {
            'transcriptionId': transcription_id,
            'sourceLanguage': detected_language,
            'sourceText': text,
            'translations': translated_by_language,
            'processingTimeMs': bulk_result.total_processing_time_ms
        })

        logger.info("Translations completed",
                    transcription_id=transcription_id,
                    target_languages=target_languages)

    except Exception as e:
        logger.error("Error processing transcription event",
                     error=str(e), transcription_id=transcription_id)
|
||||
|
||||
async def listen_for_events(self):
    """Consume 'voice-translator-events' and translate finished transcriptions.

    Every message whose JSON body has event == 'transcription_completed' is
    passed to process_transcription_event. A bad message is logged and
    skipped. An error in the subscription itself ends the loop after being
    logged. The final unsubscribe is best effort, so a dead connection cannot
    raise out of this background task.
    """
    logger.info("Starting to listen for transcription events")

    pubsub = self.redis_client.pubsub()
    await pubsub.subscribe('voice-translator-events')

    try:
        async for message in pubsub.listen():
            # Skip subscribe confirmations and other control messages
            if message['type'] != 'message':
                continue

            try:
                data = json.loads(message['data'])

                # Process transcription_completed events
                if data.get('event') == 'transcription_completed':
                    await self.process_transcription_event(data['data'])

            except json.JSONDecodeError:
                logger.error("Invalid JSON in message", data=message['data'])
            except Exception as e:
                logger.error("Error processing message", error=str(e))

    except Exception as e:
        logger.error("Error in message listener", error=str(e))
    finally:
        try:
            await pubsub.unsubscribe('voice-translator-events')
        except Exception as e:
            logger.warning("Failed to unsubscribe from events channel", error=str(e))
|
||||
|
||||
# Global service instance
|
||||
# Single shared instance used by the lifecycle hooks and every endpoint below;
# constructing it opens no connections.
translation_service = TranslationService()
|
||||
|
||||
# Strong reference to the background listener task. The event loop only keeps
# weak references to tasks, so an unreferenced task may be garbage-collected
# while it is still running.
_listener_task = None


@app.on_event("startup")
async def startup_event():
    """Initialize the service on startup"""
    global _listener_task

    await translation_service.initialize()

    # Start background task to listen for events
    _listener_task = asyncio.create_task(translation_service.listen_for_events())


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown"""
    global _listener_task

    # Stop the listener before closing the connections it is using
    task, _listener_task = _listener_task, None
    if task is not None and not task.done():
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error("Event listener failed during shutdown", error=str(e))

    await translation_service.cleanup()
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    # Reports static status plus which backends are present; it does not
    # probe Redis, PostgreSQL or Google.
    service = translation_service
    model_ready = service.nllb_model is not None
    google_ready = service.deep_translator is not None

    return HealthResponse(
        status="healthy",
        local_model_loaded=model_ready,
        device=DEVICE,
        supported_languages=list(NLLB_LANG_MAP),
        google_api_available=google_ready,
    )
|
||||
|
||||
@app.post("/translate", response_model=TranslationResult)
async def translate_endpoint(request: TranslationRequest):
    """Single translation endpoint"""
    # Thin adapter: unpack the request body and delegate to the service
    result = await translation_service.translate_text(
        text=request.text,
        source_lang=request.source_language,
        target_lang=request.target_language,
        transcription_id=request.transcription_id,
    )
    return result
|
||||
|
||||
@app.post("/translate/bulk", response_model=BulkTranslationResult)
async def bulk_translate_endpoint(request: BulkTranslationRequest):
    """Bulk translation endpoint"""
    # Thin adapter: unpack the request body and delegate to the service
    result = await translation_service.bulk_translate(
        text=request.text,
        source_lang=request.source_language,
        target_languages=request.target_languages,
        transcription_id=request.transcription_id,
    )
    return result
|
||||
|
||||
@app.get("/languages")
async def list_languages():
    """List supported languages"""
    # Display metadata, the codes the local model is mapped for, and the three
    # primary languages
    local_codes = list(NLLB_LANG_MAP)
    primary_codes = ["en", "de", "ko"]

    return {
        "supported_languages": LANGUAGE_INFO,
        "local_model_languages": local_codes,
        "primary_languages": primary_codes,
    }
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000, using the configured log level.
    server_options = dict(
        host="0.0.0.0",
        port=8000,
        log_level=LOG_LEVEL.lower(),
        access_log=True,
    )
    uvicorn.run("api:app", **server_options)
|
Reference in New Issue
Block a user