Create complete Whisper service API implementation

This commit is contained in:
2025-07-14 00:29:40 -05:00
parent 6f98226dbb
commit 6714d35ffb

View File

@ -0,0 +1,416 @@
import asyncio
import json
import logging
import os
import time
from pathlib import Path
from typing import Dict, List, Optional, Any
import uvicorn
from fastapi import FastAPI, HTTPException, UploadFile, File
from pydantic import BaseModel
import redis.asyncio as redis
import psycopg2
from psycopg2.extras import DictCursor
import structlog
from faster_whisper import WhisperModel
import torch
# Configure structured logging
structlog.configure(
processors=[
structlog.stdlib.filter_by_level,
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.JSONRenderer()
],
context_class=dict,
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)
logger = structlog.get_logger()
# Configuration
POSTGRES_URL = os.getenv('POSTGRES_URL')
REDIS_URL = os.getenv('REDIS_URL', 'redis://redis:6379')
WHISPER_MODEL = os.getenv('WHISPER_MODEL', 'large-v2')
MODEL_CACHE_DIR = '/app/models'
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
COMPUTE_TYPE = 'float16' if DEVICE == 'cuda' else 'int8'
# Set up logging level
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
# Pydantic models
class TranscriptionRequest(BaseModel):
audio_path: str
recording_id: int
language: Optional[str] = None
class TranscriptionResult(BaseModel):
text: str
language: str
language_confidence: float
segments: List[Dict[str, Any]]
processing_time_ms: int
model_used: str
class HealthResponse(BaseModel):
status: str
device: str
model: str
gpu_available: bool
gpu_memory_mb: Optional[int] = None
app = FastAPI(title="Whisper Transcription Service", version="2.0.0")
class WhisperService:
def __init__(self):
self.model = None
self.redis_client = None
self.pg_connection = None
self.model_loading = False
async def initialize(self):
"""Initialize connections and load model"""
try:
# Connect to Redis
self.redis_client = redis.from_url(REDIS_URL)
await self.redis_client.ping()
# Connect to PostgreSQL
self.pg_connection = psycopg2.connect(POSTGRES_URL)
# Load Whisper model
await self.load_model()
logger.info("Whisper service initialized successfully",
device=DEVICE, model=WHISPER_MODEL)
except Exception as e:
logger.error("Failed to initialize Whisper service", error=str(e))
raise
async def load_model(self):
"""Load the Whisper model"""
if self.model_loading:
logger.warning("Model is already loading")
return
self.model_loading = True
try:
logger.info("Loading Whisper model", model=WHISPER_MODEL, device=DEVICE)
# Create model cache directory
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
# Load model with optimizations
self.model = WhisperModel(
WHISPER_MODEL,
device=DEVICE,
compute_type=COMPUTE_TYPE,
download_root=MODEL_CACHE_DIR,
local_files_only=False # Allow downloading if not cached
)
# Log GPU memory usage if available
if DEVICE == 'cuda' and torch.cuda.is_available():
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**2
used_memory = torch.cuda.memory_allocated(0) / 1024**2
logger.info("GPU memory status",
total_mb=gpu_memory, used_mb=used_memory)
logger.info("Whisper model loaded successfully")
except Exception as e:
logger.error("Failed to load Whisper model", error=str(e))
raise
finally:
self.model_loading = False
async def cleanup(self):
"""Clean up connections"""
if self.redis_client:
await self.redis_client.close()
if self.pg_connection:
self.pg_connection.close()
async def transcribe_audio(self, audio_path: str, recording_id: int,
language: Optional[str] = None) -> TranscriptionResult:
"""Transcribe audio file"""
if not self.model:
raise HTTPException(status_code=503, detail="Model not loaded")
if not os.path.exists(audio_path):
raise HTTPException(status_code=404, detail="Audio file not found")
start_time = time.time()
try:
logger.info("Starting transcription",
recording_id=recording_id, audio_path=audio_path)
# Transcribe with faster-whisper
segments_generator, info = self.model.transcribe(
audio_path,
language=language,
beam_size=5,
best_of=5,
temperature=0.0,
condition_on_previous_text=False,
vad_filter=True,
vad_parameters=dict(min_silence_duration_ms=500)
)
# Convert segments to list and extract text
segments = []
full_text_parts = []
for segment in segments_generator:
segment_dict = {
'start': segment.start,
'end': segment.end,
'text': segment.text.strip(),
'words': [
{
'start': word.start,
'end': word.end,
'word': word.word,
'probability': word.probability
} for word in segment.words
] if hasattr(segment, 'words') and segment.words else []
}
segments.append(segment_dict)
full_text_parts.append(segment.text.strip())
# Combine all text
full_text = ' '.join(full_text_parts).strip()
processing_time_ms = int((time.time() - start_time) * 1000)
# Create result
result = TranscriptionResult(
text=full_text,
language=info.language,
language_confidence=info.language_probability,
segments=segments,
processing_time_ms=processing_time_ms,
model_used=WHISPER_MODEL
)
logger.info("Transcription completed",
recording_id=recording_id,
language=info.language,
confidence=info.language_probability,
text_length=len(full_text),
processing_time_ms=processing_time_ms)
return result
except Exception as e:
logger.error("Error during transcription",
error=str(e), recording_id=recording_id)
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
async def update_transcription_record(self, recording_id: int, result: TranscriptionResult):
"""Update transcription record in database"""
try:
with self.pg_connection.cursor() as cursor:
# Insert transcription record
cursor.execute("""
INSERT INTO transcriptions
(recording_id, detected_language, language_confidence, transcription_text,
word_count, processing_time_ms, whisper_model, status)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id, uuid
""", (
recording_id,
result.language,
result.language_confidence,
result.text,
len(result.text.split()) if result.text else 0,
result.processing_time_ms,
result.model_used,
'completed'
))
transcription_record = cursor.fetchone()
self.pg_connection.commit()
logger.info("Created transcription record",
transcription_id=transcription_record[0],
recording_id=recording_id)
return transcription_record[0] # Return transcription ID
except Exception as e:
logger.error("Failed to update transcription record",
error=str(e), recording_id=recording_id)
self.pg_connection.rollback()
raise
async def publish_event(self, event: str, data: Dict[str, Any]):
"""Publish event to Redis"""
try:
message = {
'event': event,
'data': data,
'timestamp': time.time(),
'service': 'whisper-service'
}
await self.redis_client.publish('voice-translator-events', json.dumps(message))
logger.debug("Published event", event=event, data=data)
except Exception as e:
logger.error("Failed to publish event", error=str(e), event=event)
async def process_audio_event(self, message_data: Dict[str, Any]):
"""Process audio_processed event"""
recording_id = message_data.get('recordingId')
processed_path = message_data.get('processedPath')
if not recording_id or not processed_path:
logger.error("Invalid audio_processed message", data=message_data)
return
try:
# Transcribe the audio
result = await self.transcribe_audio(processed_path, recording_id)
# Update database
transcription_id = await self.update_transcription_record(recording_id, result)
# Publish transcription completed event
await self.publish_event('transcription_completed', {
'recordingId': recording_id,
'transcriptionId': transcription_id,
'text': result.text,
'language': result.language,
'languageConfidence': result.language_confidence,
'processingTimeMs': result.processing_time_ms
})
except Exception as e:
logger.error("Error processing audio event",
error=str(e), recording_id=recording_id)
async def listen_for_events(self):
"""Listen for audio processing events"""
logger.info("Starting to listen for audio events")
pubsub = self.redis_client.pubsub()
await pubsub.subscribe('voice-translator-events')
try:
async for message in pubsub.listen():
if message['type'] == 'message':
try:
data = json.loads(message['data'])
# Process audio_processed events
if data.get('event') == 'audio_processed':
await self.process_audio_event(data['data'])
except json.JSONDecodeError:
logger.error("Invalid JSON in message", data=message['data'])
except Exception as e:
logger.error("Error processing message", error=str(e))
except Exception as e:
logger.error("Error in message listener", error=str(e))
finally:
await pubsub.unsubscribe('voice-translator-events')
# Global service instance
whisper_service = WhisperService()
@app.on_event("startup")
async def startup_event():
"""Initialize the service on startup"""
await whisper_service.initialize()
# Start background task to listen for events
asyncio.create_task(whisper_service.listen_for_events())
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup on shutdown"""
await whisper_service.cleanup()
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint"""
gpu_memory = None
if torch.cuda.is_available():
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**2
return HealthResponse(
status="healthy" if whisper_service.model else "loading",
device=DEVICE,
model=WHISPER_MODEL,
gpu_available=torch.cuda.is_available(),
gpu_memory_mb=int(gpu_memory) if gpu_memory else None
)
@app.post("/transcribe", response_model=TranscriptionResult)
async def transcribe_endpoint(request: TranscriptionRequest):
"""Manual transcription endpoint"""
return await whisper_service.transcribe_audio(
request.audio_path,
request.recording_id,
request.language
)
@app.post("/transcribe/upload")
async def transcribe_upload(file: UploadFile = File(...), recording_id: int = 0):
"""Upload and transcribe audio file"""
if not file.filename.endswith(('.wav', '.mp3', '.m4a', '.flac')):
raise HTTPException(status_code=400, detail="Unsupported audio format")
# Save uploaded file temporarily
temp_path = f"/tmp/{file.filename}"
try:
with open(temp_path, "wb") as f:
content = await file.read()
f.write(content)
# Transcribe
result = await whisper_service.transcribe_audio(temp_path, recording_id)
return result
finally:
# Clean up temp file
if os.path.exists(temp_path):
os.remove(temp_path)
@app.get("/models")
async def list_models():
"""List available Whisper models"""
return {
"current_model": WHISPER_MODEL,
"device": DEVICE,
"compute_type": COMPUTE_TYPE,
"available_models": [
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2", "large-v3"
]
}
if __name__ == "__main__":
uvicorn.run(
"api:app",
host="0.0.0.0",
port=8000,
log_level=LOG_LEVEL.lower(),
access_log=True
)