ai-lab-transformers-playground/src/api/app.py

517 lines
18 KiB
Python

"""
FastAPI application for AI Lab
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Dict, Any
import logging
import torch
from .models import (
TextRequest, TextListRequest, QARequest, FillMaskRequest, TextGenRequest,
SentimentResponse, NERResponse, QAResponse, FillMaskResponse,
ModerationResponse, TextGenResponse, BatchResponse
)
# Registry of loaded pipeline objects keyed by task name ("sentiment", "ner",
# "qa", "fillmask", "moderation", "textgen"). Filled by `lifespan` at startup,
# emptied on shutdown; endpoints answer 503 when their key is absent.
pipelines: Dict[str, Any] = dict()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every AI pipeline on startup and release them on shutdown.

    Each pipeline is imported and instantiated independently. A failure in
    one of them is logged with its traceback, and the API still starts with
    the pipelines that did load. Endpoints check the ``pipelines`` registry
    and answer 503 for any task that is missing.

    Args:
        app: The FastAPI application; unused here, but part of the lifespan
            signature FastAPI calls.
    """
    # Imported lazily, by dotted path, to avoid circular imports.
    import importlib

    # (registry key, module path, class name) for every pipeline to load.
    pipeline_specs = [
        ("sentiment", "src.pipelines.sentiment", "SentimentAnalyzer"),
        ("ner", "src.pipelines.ner", "NamedEntityRecognizer"),
        ("qa", "src.pipelines.qa", "QuestionAnsweringSystem"),
        ("fillmask", "src.pipelines.fillmask", "FillMaskAnalyzer"),
        ("moderation", "src.pipelines.moderation", "ContentModerator"),
        ("textgen", "src.pipelines.textgen", "TextGenerator"),
    ]

    logging.info("Loading AI pipelines...")
    for key, module_path, class_name in pipeline_specs:
        try:
            module = importlib.import_module(module_path)
            pipelines[key] = getattr(module, class_name)()
        except Exception as e:
            # Don't raise: one broken pipeline must not prevent the others
            # from loading or the API from starting.
            logging.exception(f"Error loading {key} pipeline: {e}")

    missing = [key for key, _, _ in pipeline_specs if key not in pipelines]
    if missing:
        logging.warning(f"Pipelines not available: {missing}")
    else:
        logging.info("All pipelines loaded successfully!")

    try:
        yield
    finally:
        # Cleanup on shutdown, even if the application exits with an error.
        pipelines.clear()
        logging.info("Pipelines cleaned up")
# Create FastAPI app.
# Models are loaded and unloaded by `lifespan`. `swagger_ui_parameters`
# only tune the interactive /docs page: dark syntax theme, "Try it out"
# enabled, request snippets, persisted auth, and per-request timing.
app = FastAPI(
    lifespan=lifespan,
    title="AI Lab API",
    version="1.0.0",
    description="API for various AI/ML pipelines using transformers",
    swagger_ui_parameters={
        "syntaxHighlight.theme": "obsidian",
        "tryItOutEnabled": True,
        "requestSnippetsEnabled": True,
        "persistAuthorization": True,
        "displayRequestDuration": True,
        "defaultModelRendering": "model",
    },
)
# Add CORS middleware: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# wide open. Restrict allow_origins (and reconsider credentials) before
# exposing this API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],  # Configure appropriately for production
)
@app.get("/")
async def root():
    """Return a welcome payload with the API version and its public routes."""
    single_item_routes = [
        "/sentiment",
        "/ner",
        "/qa",
        "/fillmask",
        "/moderation",
        "/textgen",
    ]
    batch_routes = [
        "/sentiment/batch",
        "/ner/batch",
        "/fillmask/batch",
        "/moderation/batch",
        "/textgen/batch",
    ]
    service_routes = ["/health", "/docs"]

    payload = {
        "message": "Welcome to AI Lab API",
        "version": "1.0.0",
        "available_endpoints": single_item_routes + batch_routes + service_routes,
    }
    return payload
@app.get("/health")
async def health_check():
    """Report liveness plus which pipelines are currently loaded.

    NOTE(review): "status" is always "healthy", even when no pipeline loaded.
    Monitoring should look at "pipelines_loaded" as well.
    """
    loaded_names = list(pipelines)
    return dict(
        status="healthy",
        pipelines_loaded=len(loaded_names),
        available_pipelines=loaded_names,
    )
@app.post("/sentiment", response_model=SentimentResponse)
async def analyze_sentiment(request: TextRequest):
    """Analyze the sentiment of a single text.

    Uses the shared startup pipeline. When ``request.model_name`` is set, a
    new ``SentimentAnalyzer`` is built for that model on this request; it is
    not cached between requests.

    Returns:
        SentimentResponse with ``success=False`` and the pipeline's message
        when the result contains an ``"error"`` key. Otherwise it carries the
        text, sentiment and confidence.

    Raises:
        HTTPException: 503 if the sentiment pipeline was not loaded at
            startup; 500 for any other unexpected error.
    """
    try:
        if "sentiment" not in pipelines:
            raise HTTPException(status_code=503, detail="Sentiment pipeline not available")
        # Use custom model if provided
        if request.model_name:
            from src.pipelines.sentiment import SentimentAnalyzer
            analyzer = SentimentAnalyzer(request.model_name)
            result = analyzer.analyze(request.text)
        else:
            result = pipelines["sentiment"].analyze(request.text)
        if "error" in result:
            return SentimentResponse(success=False, text=request.text, message=result["error"])
        return SentimentResponse(
            success=True,
            text=result["text"],
            sentiment=result["sentiment"],
            confidence=result["confidence"],
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("Sentiment endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/ner", response_model=NERResponse)
async def extract_entities(request: TextRequest):
    """Extract named entities from a single text.

    Uses the shared startup pipeline. When ``request.model_name`` is set, a
    new ``NamedEntityRecognizer`` is built for that model on this request.
    The response is ``success=False`` (never an HTTP error) in these cases:
    - the recognizer raises;
    - the result carries an ``"error"`` key;
    - the result lacks ``"original_text"`` or ``"entities"``.

    Raises:
        HTTPException: 503 if the NER pipeline was not loaded at startup;
            500 for any other unexpected error.
    """
    try:
        logging.info(f"NER request for text: {request.text[:50]}...")
        if "ner" not in pipelines:
            raise HTTPException(status_code=503, detail="NER pipeline not available")
        try:
            if request.model_name:
                from src.pipelines.ner import NamedEntityRecognizer
                ner = NamedEntityRecognizer(request.model_name)
                result = ner.recognize(request.text)
            else:
                result = pipelines["ner"].recognize(request.text)
        except Exception as pipeline_error:
            logging.error(f"Pipeline error: {str(pipeline_error)}")
            return NERResponse(success=False, text=request.text, message=f"Pipeline error: {str(pipeline_error)}")
        logging.info(f"NER result keys: {list(result.keys())}")
        if "error" in result:
            logging.error(f"NER error: {result['error']}")
            return NERResponse(success=False, text=request.text, message=result["error"])
        # Validate result structure
        if "original_text" not in result:
            logging.error(f"Missing 'original_text' in result: {result}")
            return NERResponse(success=False, text=request.text, message="Invalid NER result format")
        if "entities" not in result:
            logging.error(f"Missing 'entities' in result: {result}")
            return NERResponse(success=False, text=request.text, message="Invalid NER result format")
        return NERResponse(
            success=True,
            text=result["original_text"],
            entities=result["entities"],
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("NER endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/qa", response_model=QAResponse)
async def answer_question(request: QARequest):
    """Answer ``request.question`` using ``request.context``.

    Uses the shared startup pipeline. When ``request.model_name`` is set, a
    new ``QuestionAnsweringSystem`` is built for that model on this request.

    Returns:
        QAResponse with ``success=False`` and the pipeline's message when the
        result contains an ``"error"`` key. Otherwise it carries the question,
        context, answer and confidence.

    Raises:
        HTTPException: 503 if the QA pipeline was not loaded at startup;
            500 for any other unexpected error.
    """
    try:
        if "qa" not in pipelines:
            raise HTTPException(status_code=503, detail="QA pipeline not available")
        if request.model_name:
            from src.pipelines.qa import QuestionAnsweringSystem
            qa = QuestionAnsweringSystem(request.model_name)
            result = qa.answer(request.question, request.context)
        else:
            result = pipelines["qa"].answer(request.question, request.context)
        if "error" in result:
            return QAResponse(
                success=False,
                question=request.question,
                context=request.context,
                message=result["error"],
            )
        return QAResponse(
            success=True,
            question=result["question"],
            context=result["context"],
            answer=result["answer"],
            confidence=result["confidence"],
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("QA endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/fillmask", response_model=FillMaskResponse)
async def fill_mask(request: FillMaskRequest):
    """Predict the masked token(s) in ``request.text``.

    Uses the shared startup pipeline. When ``request.model_name`` is set, a
    new ``FillMaskAnalyzer`` is built for that model on this request.

    Returns:
        FillMaskResponse with ``success=False`` and the pipeline's message
        when the result contains an ``"error"`` key. Otherwise it carries the
        original text and the predictions.

    Raises:
        HTTPException: 503 if the fill-mask pipeline was not loaded at
            startup; 500 for any other unexpected error.
    """
    try:
        if "fillmask" not in pipelines:
            raise HTTPException(status_code=503, detail="Fill-mask pipeline not available")
        if request.model_name:
            from src.pipelines.fillmask import FillMaskAnalyzer
            fillmask = FillMaskAnalyzer(request.model_name)
            result = fillmask.predict(request.text)
        else:
            result = pipelines["fillmask"].predict(request.text)
        if "error" in result:
            return FillMaskResponse(success=False, text=request.text, message=result["error"])
        return FillMaskResponse(
            success=True,
            text=result["original_text"],
            predictions=result["predictions"],
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("Fill-mask endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/moderation", response_model=ModerationResponse)
async def moderate_content(request: TextRequest):
    """Moderate a single text for inappropriate material.

    Uses the shared startup pipeline. When ``request.model_name`` is set, a
    new ``ContentModerator`` is built for that model on this request.

    The text is reported as ``flagged`` when the moderator modified it
    (``is_modified``) or its ``toxic_score`` exceeds 0.5. ``categories``
    carries the raw moderator fields. Note that ``restored_text`` is filled
    from the moderator's ``moderated_text`` key.

    Raises:
        HTTPException: 503 if the moderation pipeline was not loaded at
            startup; 500 for any other unexpected error.
    """
    try:
        if "moderation" not in pipelines:
            raise HTTPException(status_code=503, detail="Moderation pipeline not available")
        if request.model_name:
            from src.pipelines.moderation import ContentModerator
            moderation = ContentModerator(request.model_name)
            result = moderation.moderate(request.text)
        else:
            result = pipelines["moderation"].moderate(request.text)
        if "error" in result:
            return ModerationResponse(success=False, text=request.text, message=result["error"])
        # Map the result fields correctly
        flagged = result.get("is_modified", False) or result.get("toxic_score", 0.0) > 0.5
        categories = {
            "toxic_score": result.get("toxic_score", 0.0),
            "is_modified": result.get("is_modified", False),
            "restored_text": result.get("moderated_text", request.text),
            "words_replaced": result.get("words_replaced", 0),
        }
        return ModerationResponse(
            success=True,
            text=result["original_text"],
            flagged=flagged,
            categories=categories,
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("Moderation endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/textgen", response_model=TextGenResponse)
async def generate_text(request: TextGenRequest):
    """Generate text from ``request.text`` with configurable parameters.

    Forwards ``system_prompt``, ``max_new_tokens``, ``num_return_sequences``,
    ``temperature`` and ``do_sample`` to the generator. When
    ``request.model_name`` is set, a new ``TextGenerator`` is built for that
    model on this request:

    - If it runs out of CUDA memory, the custom model is released, the CUDA
      cache is emptied, and the request is retried on the shared default
      pipeline.
    - Any other custom-model error returns ``success=False`` with a hint to
      use the default model.

    ``generated_text`` is the ``"continuation"`` of the first generation, or
    ``""`` when there are none.

    Raises:
        HTTPException: 503 if the text-generation pipeline was not loaded at
            startup; 500 for any other unexpected error.
    """
    try:
        if "textgen" not in pipelines:
            raise HTTPException(status_code=503, detail="Text generation pipeline not available")
        logging.info(f"Generating text for prompt: {request.text[:50]}...")
        # Extract generation parameters
        gen_params = {
            "system_prompt": request.system_prompt,
            "max_new_tokens": request.max_new_tokens,
            "num_return_sequences": request.num_return_sequences,
            "temperature": request.temperature,
            "do_sample": request.do_sample,
        }
        if request.model_name:
            custom_oom = False
            custom_generator = None
            try:
                from src.pipelines.textgen import TextGenerator
                custom_generator = TextGenerator(request.model_name)
                result = custom_generator.generate(request.text, **gen_params)
            except torch.cuda.OutOfMemoryError:
                logging.warning(
                    f"CUDA OOM with model {request.model_name}, "
                    "falling back to the default model"
                )
                custom_oom = True
            except Exception as model_error:
                logging.error(f"Error with custom model {request.model_name}: {model_error}")
                return TextGenResponse(
                    success=False,
                    prompt=request.text,
                    message=f"Erreur avec le modèle {request.model_name}: {str(model_error)}. Essayez avec le modèle par défaut.",
                )
            if custom_oom:
                # Retry outside the except clause so the handled OOM exception
                # is no longer active. Drop the failed model and free cached
                # CUDA blocks first; otherwise the default model competes for
                # the memory that just ran out.
                custom_generator = None
                torch.cuda.empty_cache()
                result = pipelines["textgen"].generate(request.text, **gen_params)
        else:
            result = pipelines["textgen"].generate(request.text, **gen_params)
        logging.info(f"Generation result keys: {list(result.keys())}")
        if "error" in result:
            logging.error(f"Generation error: {result['error']}")
            return TextGenResponse(success=False, prompt=request.text, message=result["error"])
        # Extract the generated text from the first generation
        generated_text = ""
        if "generations" in result and len(result["generations"]) > 0:
            # Get the continuation (text after the prompt) from the first generation
            generated_text = result["generations"][0].get("continuation", "")
            logging.info(f"Extracted generated text: {generated_text[:100]}...")
        else:
            logging.warning("No generations found in result")
        return TextGenResponse(
            success=True,
            prompt=result["prompt"],
            generated_text=generated_text,
            parameters=result.get("parameters", {}),
            generations=result.get("generations", []),
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.error(f"TextGen endpoint error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Batch processing endpoints
@app.post("/sentiment/batch", response_model=BatchResponse)
async def analyze_sentiment_batch(request: TextListRequest):
    """Analyze the sentiment of every text in ``request.texts``.

    Texts are processed one after another with a single analyzer. That is
    the shared startup pipeline, or a new ``SentimentAnalyzer`` when
    ``request.model_name`` is set. A per-text failure does not abort the
    batch: either the result has an ``"error"`` key, or the analyzer raised
    (recorded as ``{"text", "error"}``). Such failures are kept in
    ``results`` and counted in ``failed_count``.

    Raises:
        HTTPException: 503 if the sentiment pipeline was not loaded at
            startup; 500 if anything outside the per-text loop fails (e.g.
            loading a custom model).
    """
    try:
        if "sentiment" not in pipelines:
            raise HTTPException(status_code=503, detail="Sentiment pipeline not available")
        analyzer = pipelines["sentiment"]
        if request.model_name:
            from src.pipelines.sentiment import SentimentAnalyzer
            analyzer = SentimentAnalyzer(request.model_name)
        results = []
        failed_count = 0
        for text in request.texts:
            try:
                result = analyzer.analyze(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as e:
                failed_count += 1
                results.append({"text": text, "error": str(e)})
        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count,
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("Sentiment batch endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/ner/batch", response_model=BatchResponse)
async def extract_entities_batch(request: TextListRequest):
    """Extract named entities from every text in ``request.texts``.

    Texts are processed one after another with a single recognizer. That is
    the shared startup pipeline, or a new ``NamedEntityRecognizer`` when
    ``request.model_name`` is set. A per-text failure does not abort the
    batch: either the result has an ``"error"`` key, or the recognizer
    raised (recorded as ``{"text", "error"}``). Such failures are kept in
    ``results`` and counted in ``failed_count``.

    Raises:
        HTTPException: 503 if the NER pipeline was not loaded at startup;
            500 if anything outside the per-text loop fails (e.g. loading a
            custom model).
    """
    try:
        if "ner" not in pipelines:
            raise HTTPException(status_code=503, detail="NER pipeline not available")
        ner = pipelines["ner"]
        if request.model_name:
            from src.pipelines.ner import NamedEntityRecognizer
            ner = NamedEntityRecognizer(request.model_name)
        results = []
        failed_count = 0
        for text in request.texts:
            try:
                result = ner.recognize(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as e:
                failed_count += 1
                results.append({"text": text, "error": str(e)})
        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count,
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("NER batch endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/fillmask/batch", response_model=BatchResponse)
async def fill_mask_batch(request: TextListRequest):
    """Predict masked tokens for every text in ``request.texts``.

    Texts are processed one after another with a single analyzer. That is
    the shared startup pipeline, or a new ``FillMaskAnalyzer`` when
    ``request.model_name`` is set. A per-text failure does not abort the
    batch: either the result has an ``"error"`` key, or the analyzer raised
    (recorded as ``{"text", "error"}``). Such failures are kept in
    ``results`` and counted in ``failed_count``.

    Raises:
        HTTPException: 503 if the fill-mask pipeline was not loaded at
            startup; 500 if anything outside the per-text loop fails (e.g.
            loading a custom model).
    """
    try:
        if "fillmask" not in pipelines:
            raise HTTPException(status_code=503, detail="Fill-mask pipeline not available")
        fillmask = pipelines["fillmask"]
        if request.model_name:
            from src.pipelines.fillmask import FillMaskAnalyzer
            fillmask = FillMaskAnalyzer(request.model_name)
        results = []
        failed_count = 0
        for text in request.texts:
            try:
                result = fillmask.predict(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as e:
                failed_count += 1
                results.append({"text": text, "error": str(e)})
        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count,
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("Fill-mask batch endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/moderation/batch", response_model=BatchResponse)
async def moderate_content_batch(request: TextListRequest):
    """Moderate every text in ``request.texts``.

    Texts are processed one after another with a single moderator. That is
    the shared startup pipeline, or a new ``ContentModerator`` when
    ``request.model_name`` is set. Raw moderator results are returned
    unmapped, unlike the single-text endpoint. A per-text failure does not
    abort the batch: either the result has an ``"error"`` key, or the
    moderator raised (recorded as ``{"text", "error"}``). Such failures are
    kept in ``results`` and counted in ``failed_count``.

    Raises:
        HTTPException: 503 if the moderation pipeline was not loaded at
            startup; 500 if anything outside the per-text loop fails (e.g.
            loading a custom model).
    """
    try:
        if "moderation" not in pipelines:
            raise HTTPException(status_code=503, detail="Moderation pipeline not available")
        moderation = pipelines["moderation"]
        if request.model_name:
            from src.pipelines.moderation import ContentModerator
            moderation = ContentModerator(request.model_name)
        results = []
        failed_count = 0
        for text in request.texts:
            try:
                result = moderation.moderate(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as e:
                failed_count += 1
                results.append({"text": text, "error": str(e)})
        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count,
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("Moderation batch endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/textgen/batch", response_model=BatchResponse)
async def generate_text_batch(request: TextListRequest):
    """Generate text for every prompt in ``request.texts``.

    Prompts are processed one after another with a single generator. That is
    the shared startup pipeline, or a new ``TextGenerator`` when
    ``request.model_name`` is set. The generator's default generation
    parameters are used; nothing from the request is forwarded. A per-prompt
    failure does not abort the batch: either the result has an ``"error"``
    key, or the generator raised (recorded as ``{"text", "error"}``). Such
    failures are kept in ``results`` and counted in ``failed_count``.

    Raises:
        HTTPException: 503 if the text-generation pipeline was not loaded at
            startup; 500 if anything outside the per-prompt loop fails (e.g.
            loading a custom model).
    """
    try:
        if "textgen" not in pipelines:
            raise HTTPException(status_code=503, detail="Text generation pipeline not available")
        textgen = pipelines["textgen"]
        if request.model_name:
            from src.pipelines.textgen import TextGenerator
            textgen = TextGenerator(request.model_name)
        results = []
        failed_count = 0
        for text in request.texts:
            try:
                result = textgen.generate(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as e:
                failed_count += 1
                results.append({"text": text, "error": str(e)})
        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count,
        )
    except HTTPException:
        # Pass intended HTTP errors (e.g. the 503 above) through unchanged.
        # Otherwise the generic handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logging.exception("TextGen batch endpoint error")
        raise HTTPException(status_code=500, detail=str(e)) from e