Create complete Whisper service API implementation
This commit is contained in:
416
services/whisper-service/src/api.py
Normal file
416
services/whisper-service/src/api.py
Normal file
@ -0,0 +1,416 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
import redis.asyncio as redis
|
||||
import psycopg2
|
||||
from psycopg2.extras import DictCursor
|
||||
import structlog
|
||||
from faster_whisper import WhisperModel
|
||||
import torch
|
||||
|
||||
# Configure structured logging
|
||||
structlog.configure(
|
||||
processors=[
|
||||
structlog.stdlib.filter_by_level,
|
||||
structlog.stdlib.add_log_level,
|
||||
structlog.stdlib.add_logger_name,
|
||||
structlog.stdlib.PositionalArgumentsFormatter(),
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.JSONRenderer()
|
||||
],
|
||||
context_class=dict,
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
wrapper_class=structlog.stdlib.BoundLogger,
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Configuration
|
||||
POSTGRES_URL = os.getenv('POSTGRES_URL')
|
||||
REDIS_URL = os.getenv('REDIS_URL', 'redis://redis:6379')
|
||||
WHISPER_MODEL = os.getenv('WHISPER_MODEL', 'large-v2')
|
||||
MODEL_CACHE_DIR = '/app/models'
|
||||
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
|
||||
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
COMPUTE_TYPE = 'float16' if DEVICE == 'cuda' else 'int8'
|
||||
|
||||
# Set up logging level
|
||||
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
|
||||
|
||||
# Pydantic models
|
||||
class TranscriptionRequest(BaseModel):
|
||||
audio_path: str
|
||||
recording_id: int
|
||||
language: Optional[str] = None
|
||||
|
||||
class TranscriptionResult(BaseModel):
|
||||
text: str
|
||||
language: str
|
||||
language_confidence: float
|
||||
segments: List[Dict[str, Any]]
|
||||
processing_time_ms: int
|
||||
model_used: str
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
device: str
|
||||
model: str
|
||||
gpu_available: bool
|
||||
gpu_memory_mb: Optional[int] = None
|
||||
|
||||
app = FastAPI(title="Whisper Transcription Service", version="2.0.0")
|
||||
|
||||
class WhisperService:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.redis_client = None
|
||||
self.pg_connection = None
|
||||
self.model_loading = False
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize connections and load model"""
|
||||
try:
|
||||
# Connect to Redis
|
||||
self.redis_client = redis.from_url(REDIS_URL)
|
||||
await self.redis_client.ping()
|
||||
|
||||
# Connect to PostgreSQL
|
||||
self.pg_connection = psycopg2.connect(POSTGRES_URL)
|
||||
|
||||
# Load Whisper model
|
||||
await self.load_model()
|
||||
|
||||
logger.info("Whisper service initialized successfully",
|
||||
device=DEVICE, model=WHISPER_MODEL)
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize Whisper service", error=str(e))
|
||||
raise
|
||||
|
||||
async def load_model(self):
|
||||
"""Load the Whisper model"""
|
||||
if self.model_loading:
|
||||
logger.warning("Model is already loading")
|
||||
return
|
||||
|
||||
self.model_loading = True
|
||||
try:
|
||||
logger.info("Loading Whisper model", model=WHISPER_MODEL, device=DEVICE)
|
||||
|
||||
# Create model cache directory
|
||||
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
|
||||
|
||||
# Load model with optimizations
|
||||
self.model = WhisperModel(
|
||||
WHISPER_MODEL,
|
||||
device=DEVICE,
|
||||
compute_type=COMPUTE_TYPE,
|
||||
download_root=MODEL_CACHE_DIR,
|
||||
local_files_only=False # Allow downloading if not cached
|
||||
)
|
||||
|
||||
# Log GPU memory usage if available
|
||||
if DEVICE == 'cuda' and torch.cuda.is_available():
|
||||
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**2
|
||||
used_memory = torch.cuda.memory_allocated(0) / 1024**2
|
||||
logger.info("GPU memory status",
|
||||
total_mb=gpu_memory, used_mb=used_memory)
|
||||
|
||||
logger.info("Whisper model loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to load Whisper model", error=str(e))
|
||||
raise
|
||||
finally:
|
||||
self.model_loading = False
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up connections"""
|
||||
if self.redis_client:
|
||||
await self.redis_client.close()
|
||||
if self.pg_connection:
|
||||
self.pg_connection.close()
|
||||
|
||||
async def transcribe_audio(self, audio_path: str, recording_id: int,
|
||||
language: Optional[str] = None) -> TranscriptionResult:
|
||||
"""Transcribe audio file"""
|
||||
if not self.model:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
logger.info("Starting transcription",
|
||||
recording_id=recording_id, audio_path=audio_path)
|
||||
|
||||
# Transcribe with faster-whisper
|
||||
segments_generator, info = self.model.transcribe(
|
||||
audio_path,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
best_of=5,
|
||||
temperature=0.0,
|
||||
condition_on_previous_text=False,
|
||||
vad_filter=True,
|
||||
vad_parameters=dict(min_silence_duration_ms=500)
|
||||
)
|
||||
|
||||
# Convert segments to list and extract text
|
||||
segments = []
|
||||
full_text_parts = []
|
||||
|
||||
for segment in segments_generator:
|
||||
segment_dict = {
|
||||
'start': segment.start,
|
||||
'end': segment.end,
|
||||
'text': segment.text.strip(),
|
||||
'words': [
|
||||
{
|
||||
'start': word.start,
|
||||
'end': word.end,
|
||||
'word': word.word,
|
||||
'probability': word.probability
|
||||
} for word in segment.words
|
||||
] if hasattr(segment, 'words') and segment.words else []
|
||||
}
|
||||
segments.append(segment_dict)
|
||||
full_text_parts.append(segment.text.strip())
|
||||
|
||||
# Combine all text
|
||||
full_text = ' '.join(full_text_parts).strip()
|
||||
|
||||
processing_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Create result
|
||||
result = TranscriptionResult(
|
||||
text=full_text,
|
||||
language=info.language,
|
||||
language_confidence=info.language_probability,
|
||||
segments=segments,
|
||||
processing_time_ms=processing_time_ms,
|
||||
model_used=WHISPER_MODEL
|
||||
)
|
||||
|
||||
logger.info("Transcription completed",
|
||||
recording_id=recording_id,
|
||||
language=info.language,
|
||||
confidence=info.language_probability,
|
||||
text_length=len(full_text),
|
||||
processing_time_ms=processing_time_ms)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error during transcription",
|
||||
error=str(e), recording_id=recording_id)
|
||||
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
||||
|
||||
async def update_transcription_record(self, recording_id: int, result: TranscriptionResult):
|
||||
"""Update transcription record in database"""
|
||||
try:
|
||||
with self.pg_connection.cursor() as cursor:
|
||||
# Insert transcription record
|
||||
cursor.execute("""
|
||||
INSERT INTO transcriptions
|
||||
(recording_id, detected_language, language_confidence, transcription_text,
|
||||
word_count, processing_time_ms, whisper_model, status)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id, uuid
|
||||
""", (
|
||||
recording_id,
|
||||
result.language,
|
||||
result.language_confidence,
|
||||
result.text,
|
||||
len(result.text.split()) if result.text else 0,
|
||||
result.processing_time_ms,
|
||||
result.model_used,
|
||||
'completed'
|
||||
))
|
||||
|
||||
transcription_record = cursor.fetchone()
|
||||
self.pg_connection.commit()
|
||||
|
||||
logger.info("Created transcription record",
|
||||
transcription_id=transcription_record[0],
|
||||
recording_id=recording_id)
|
||||
|
||||
return transcription_record[0] # Return transcription ID
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update transcription record",
|
||||
error=str(e), recording_id=recording_id)
|
||||
self.pg_connection.rollback()
|
||||
raise
|
||||
|
||||
async def publish_event(self, event: str, data: Dict[str, Any]):
|
||||
"""Publish event to Redis"""
|
||||
try:
|
||||
message = {
|
||||
'event': event,
|
||||
'data': data,
|
||||
'timestamp': time.time(),
|
||||
'service': 'whisper-service'
|
||||
}
|
||||
|
||||
await self.redis_client.publish('voice-translator-events', json.dumps(message))
|
||||
logger.debug("Published event", event=event, data=data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to publish event", error=str(e), event=event)
|
||||
|
||||
async def process_audio_event(self, message_data: Dict[str, Any]):
|
||||
"""Process audio_processed event"""
|
||||
recording_id = message_data.get('recordingId')
|
||||
processed_path = message_data.get('processedPath')
|
||||
|
||||
if not recording_id or not processed_path:
|
||||
logger.error("Invalid audio_processed message", data=message_data)
|
||||
return
|
||||
|
||||
try:
|
||||
# Transcribe the audio
|
||||
result = await self.transcribe_audio(processed_path, recording_id)
|
||||
|
||||
# Update database
|
||||
transcription_id = await self.update_transcription_record(recording_id, result)
|
||||
|
||||
# Publish transcription completed event
|
||||
await self.publish_event('transcription_completed', {
|
||||
'recordingId': recording_id,
|
||||
'transcriptionId': transcription_id,
|
||||
'text': result.text,
|
||||
'language': result.language,
|
||||
'languageConfidence': result.language_confidence,
|
||||
'processingTimeMs': result.processing_time_ms
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing audio event",
|
||||
error=str(e), recording_id=recording_id)
|
||||
|
||||
async def listen_for_events(self):
|
||||
"""Listen for audio processing events"""
|
||||
logger.info("Starting to listen for audio events")
|
||||
|
||||
pubsub = self.redis_client.pubsub()
|
||||
await pubsub.subscribe('voice-translator-events')
|
||||
|
||||
try:
|
||||
async for message in pubsub.listen():
|
||||
if message['type'] == 'message':
|
||||
try:
|
||||
data = json.loads(message['data'])
|
||||
|
||||
# Process audio_processed events
|
||||
if data.get('event') == 'audio_processed':
|
||||
await self.process_audio_event(data['data'])
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Invalid JSON in message", data=message['data'])
|
||||
except Exception as e:
|
||||
logger.error("Error processing message", error=str(e))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in message listener", error=str(e))
|
||||
finally:
|
||||
await pubsub.unsubscribe('voice-translator-events')
|
||||
|
||||
# Global service instance
|
||||
whisper_service = WhisperService()
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize the service on startup"""
|
||||
await whisper_service.initialize()
|
||||
|
||||
# Start background task to listen for events
|
||||
asyncio.create_task(whisper_service.listen_for_events())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Cleanup on shutdown"""
|
||||
await whisper_service.cleanup()
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
gpu_memory = None
|
||||
if torch.cuda.is_available():
|
||||
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**2
|
||||
|
||||
return HealthResponse(
|
||||
status="healthy" if whisper_service.model else "loading",
|
||||
device=DEVICE,
|
||||
model=WHISPER_MODEL,
|
||||
gpu_available=torch.cuda.is_available(),
|
||||
gpu_memory_mb=int(gpu_memory) if gpu_memory else None
|
||||
)
|
||||
|
||||
@app.post("/transcribe", response_model=TranscriptionResult)
|
||||
async def transcribe_endpoint(request: TranscriptionRequest):
|
||||
"""Manual transcription endpoint"""
|
||||
return await whisper_service.transcribe_audio(
|
||||
request.audio_path,
|
||||
request.recording_id,
|
||||
request.language
|
||||
)
|
||||
|
||||
@app.post("/transcribe/upload")
|
||||
async def transcribe_upload(file: UploadFile = File(...), recording_id: int = 0):
|
||||
"""Upload and transcribe audio file"""
|
||||
if not file.filename.endswith(('.wav', '.mp3', '.m4a', '.flac')):
|
||||
raise HTTPException(status_code=400, detail="Unsupported audio format")
|
||||
|
||||
# Save uploaded file temporarily
|
||||
temp_path = f"/tmp/{file.filename}"
|
||||
|
||||
try:
|
||||
with open(temp_path, "wb") as f:
|
||||
content = await file.read()
|
||||
f.write(content)
|
||||
|
||||
# Transcribe
|
||||
result = await whisper_service.transcribe_audio(temp_path, recording_id)
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
@app.get("/models")
|
||||
async def list_models():
|
||||
"""List available Whisper models"""
|
||||
return {
|
||||
"current_model": WHISPER_MODEL,
|
||||
"device": DEVICE,
|
||||
"compute_type": COMPUTE_TYPE,
|
||||
"available_models": [
|
||||
"tiny", "tiny.en", "base", "base.en",
|
||||
"small", "small.en", "medium", "medium.en",
|
||||
"large", "large-v1", "large-v2", "large-v3"
|
||||
]
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"api:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
log_level=LOG_LEVEL.lower(),
|
||||
access_log=True
|
||||
)
|
Reference in New Issue
Block a user