517 lines
18 KiB
Python
517 lines
18 KiB
Python
"""
|
|
FastAPI application for AI Lab
|
|
"""
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from contextlib import asynccontextmanager
|
|
from typing import Dict, Any
|
|
import logging
|
|
import torch
|
|
|
|
from .models import (
|
|
TextRequest, TextListRequest, QARequest, FillMaskRequest, TextGenRequest,
|
|
SentimentResponse, NERResponse, QAResponse, FillMaskResponse,
|
|
ModerationResponse, TextGenResponse, BatchResponse
|
|
)
|
|
|
|
# Global pipeline instances, keyed by short name ("sentiment", "ner", "qa",
# "fillmask", "moderation", "textgen"). The dict is filled by `lifespan` on
# startup and cleared on shutdown; each endpoint checks for its key before use.
pipelines: Dict[str, Any] = {}
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan: load pipelines on startup, clear on shutdown.

    Each pipeline is imported and constructed independently, so a failure in
    one of them is logged and skipped instead of preventing the remaining
    pipelines from loading. Startup never raises; endpoints check the
    ``pipelines`` registry to find out whether their pipeline is available.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. Control returns to the framework while the application serves.
    """
    import importlib

    # (registry key, module path, class name), in load order.
    pipeline_specs = (
        ("sentiment", "src.pipelines.sentiment", "SentimentAnalyzer"),
        ("ner", "src.pipelines.ner", "NamedEntityRecognizer"),
        ("qa", "src.pipelines.qa", "QuestionAnsweringSystem"),
        ("fillmask", "src.pipelines.fillmask", "FillMaskAnalyzer"),
        ("moderation", "src.pipelines.moderation", "ContentModerator"),
        ("textgen", "src.pipelines.textgen", "TextGenerator"),
    )

    logging.info("Loading AI pipelines...")
    failed = []
    for key, module_path, class_name in pipeline_specs:
        try:
            # Pipeline modules are imported here, not at module top level,
            # to avoid circular imports.
            pipeline_cls = getattr(importlib.import_module(module_path), class_name)
            pipelines[key] = pipeline_cls()
        except Exception as e:
            # Don't raise, just log - allows API to start without all pipelines
            logging.error(f"Error loading pipeline '{key}': {e}", exc_info=True)
            failed.append(key)

    if failed:
        logging.warning(f"Pipelines that failed to load: {failed}")
    else:
        logging.info("All pipelines loaded successfully!")

    try:
        yield
    finally:
        # Cleanup on shutdown; `finally` ensures it also runs when the
        # application exits through an exception.
        pipelines.clear()
        logging.info("Pipelines cleaned up")
|
|
|
|
|
|
# Create FastAPI app. `lifespan` loads the pipelines on startup and clears them
# on shutdown; `swagger_ui_parameters` customises the interactive docs page.
app = FastAPI(
    title="AI Lab API",
    description="API for various AI/ML pipelines using transformers",
    version="1.0.0",
    lifespan=lifespan,
    swagger_ui_parameters={
        "syntaxHighlight.theme": "obsidian",
        "tryItOutEnabled": True,
        "requestSnippetsEnabled": True,
        "persistAuthorization": True,
        "displayRequestDuration": True,
        "defaultModelRendering": "model"
    }
)
|
|
|
|
# Add CORS middleware.
# NOTE(review): every origin, method and header is allowed while credentials
# are enabled. This is very permissive; restrict `allow_origins` to the known
# front-end origins before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.get("/")
async def root():
    """Return a welcome payload with the API version and the route list."""
    single_routes = [
        "/sentiment",
        "/ner",
        "/qa",
        "/fillmask",
        "/moderation",
        "/textgen",
    ]
    batch_routes = [
        "/sentiment/batch",
        "/ner/batch",
        "/fillmask/batch",
        "/moderation/batch",
        "/textgen/batch",
    ]
    utility_routes = ["/health", "/docs"]

    # Order matters to clients that display the list: single-item routes,
    # then batch routes, then utility routes.
    endpoints = [*single_routes, *batch_routes, *utility_routes]
    return dict(
        message="Welcome to AI Lab API",
        version="1.0.0",
        available_endpoints=endpoints,
    )
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report service status together with the names of the loaded pipelines."""
    loaded_names = list(pipelines.keys())
    payload = {"status": "healthy"}
    payload["pipelines_loaded"] = len(loaded_names)
    payload["available_pipelines"] = loaded_names
    return payload
|
|
|
|
|
|
@app.post("/sentiment", response_model=SentimentResponse)
async def analyze_sentiment(request: TextRequest):
    """Analyze sentiment of a text.

    Args:
        request: Payload carrying ``text`` and an optional ``model_name``. When
            ``model_name`` is set, a new ``SentimentAnalyzer`` is constructed
            for this request instead of using the preloaded pipeline.

    Returns:
        SentimentResponse: ``success=True`` with the sentiment and confidence,
        or ``success=False`` with the error message reported by the pipeline.

    Raises:
        HTTPException: 503 when the sentiment pipeline is not loaded, 500 for
            any other unexpected failure.
    """
    try:
        if "sentiment" not in pipelines:
            raise HTTPException(status_code=503, detail="Sentiment pipeline not available")

        # Use custom model if provided
        if request.model_name:
            from src.pipelines.sentiment import SentimentAnalyzer
            analyzer = SentimentAnalyzer(request.model_name)
            result = analyzer.analyze(request.text)
        else:
            result = pipelines["sentiment"].analyze(request.text)

        if "error" in result:
            return SentimentResponse(success=False, text=request.text, message=result["error"])

        return SentimentResponse(
            success=True,
            text=result["text"],
            sentiment=result["sentiment"],
            confidence=result["confidence"]
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/ner", response_model=NERResponse)
async def extract_entities(request: TextRequest):
    """Extract named entities from text.

    Args:
        request: Payload carrying ``text`` and an optional ``model_name``. When
            ``model_name`` is set, a new ``NamedEntityRecognizer`` is
            constructed for this request instead of using the preloaded one.

    Returns:
        NERResponse: ``success=True`` with the recognised entities, or
        ``success=False`` with a message when the pipeline raises, reports an
        error, or returns a result without ``original_text`` / ``entities``.

    Raises:
        HTTPException: 503 when the NER pipeline is not loaded, 500 for any
            other unexpected failure.
    """
    try:
        logging.info(f"NER request for text: {request.text[:50]}...")
        if "ner" not in pipelines:
            raise HTTPException(status_code=503, detail="NER pipeline not available")

        try:
            if request.model_name:
                from src.pipelines.ner import NamedEntityRecognizer
                ner = NamedEntityRecognizer(request.model_name)
                result = ner.recognize(request.text)
            else:
                result = pipelines["ner"].recognize(request.text)
        except Exception as pipeline_error:
            # Pipeline failures are reported in the response body rather than
            # as an HTTP error.
            logging.error(f"Pipeline error: {str(pipeline_error)}")
            return NERResponse(success=False, text=request.text, message=f"Pipeline error: {str(pipeline_error)}")

        logging.info(f"NER result keys: {list(result.keys())}")

        if "error" in result:
            logging.error(f"NER error: {result['error']}")
            return NERResponse(success=False, text=request.text, message=result["error"])

        # Validate result structure
        for required_key in ("original_text", "entities"):
            if required_key not in result:
                logging.error(f"Missing '{required_key}' in result: {result}")
                return NERResponse(success=False, text=request.text, message="Invalid NER result format")

        return NERResponse(
            success=True,
            text=result["original_text"],
            entities=result["entities"]
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/qa", response_model=QAResponse)
async def answer_question(request: QARequest):
    """Answer a question based on context.

    Args:
        request: Payload carrying ``question``, ``context`` and an optional
            ``model_name``. When ``model_name`` is set, a new
            ``QuestionAnsweringSystem`` is constructed for this request.

    Returns:
        QAResponse: ``success=True`` with the answer and confidence, or
        ``success=False`` with the error message reported by the pipeline.

    Raises:
        HTTPException: 503 when the QA pipeline is not loaded, 500 for any
            other unexpected failure.
    """
    try:
        if "qa" not in pipelines:
            raise HTTPException(status_code=503, detail="QA pipeline not available")

        if request.model_name:
            from src.pipelines.qa import QuestionAnsweringSystem
            qa = QuestionAnsweringSystem(request.model_name)
            result = qa.answer(request.question, request.context)
        else:
            result = pipelines["qa"].answer(request.question, request.context)

        if "error" in result:
            return QAResponse(
                success=False,
                question=request.question,
                context=request.context,
                message=result["error"]
            )

        return QAResponse(
            success=True,
            question=result["question"],
            context=result["context"],
            answer=result["answer"],
            confidence=result["confidence"]
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/fillmask", response_model=FillMaskResponse)
async def fill_mask(request: FillMaskRequest):
    """Fill masked words in text.

    Args:
        request: Payload carrying ``text`` and an optional ``model_name``. When
            ``model_name`` is set, a new ``FillMaskAnalyzer`` is constructed
            for this request instead of using the preloaded pipeline.

    Returns:
        FillMaskResponse: ``success=True`` with the predictions, or
        ``success=False`` with the error message reported by the pipeline.

    Raises:
        HTTPException: 503 when the fill-mask pipeline is not loaded, 500 for
            any other unexpected failure.
    """
    try:
        if "fillmask" not in pipelines:
            raise HTTPException(status_code=503, detail="Fill-mask pipeline not available")

        if request.model_name:
            from src.pipelines.fillmask import FillMaskAnalyzer
            fillmask = FillMaskAnalyzer(request.model_name)
            result = fillmask.predict(request.text)
        else:
            result = pipelines["fillmask"].predict(request.text)

        if "error" in result:
            return FillMaskResponse(success=False, text=request.text, message=result["error"])

        return FillMaskResponse(
            success=True,
            text=result["original_text"],
            predictions=result["predictions"]
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/moderation", response_model=ModerationResponse)
async def moderate_content(request: TextRequest):
    """Moderate content for inappropriate material.

    Args:
        request: Payload carrying ``text`` and an optional ``model_name``. When
            ``model_name`` is set, a new ``ContentModerator`` is constructed
            for this request instead of using the preloaded pipeline.

    Returns:
        ModerationResponse: ``success=True`` with a ``flagged`` verdict and a
        ``categories`` dict built from the pipeline result, or
        ``success=False`` with the error message reported by the pipeline.

    Raises:
        HTTPException: 503 when the moderation pipeline is not loaded, 500 for
            any other unexpected failure.
    """
    try:
        if "moderation" not in pipelines:
            raise HTTPException(status_code=503, detail="Moderation pipeline not available")

        if request.model_name:
            from src.pipelines.moderation import ContentModerator
            moderation = ContentModerator(request.model_name)
            result = moderation.moderate(request.text)
        else:
            result = pipelines["moderation"].moderate(request.text)

        if "error" in result:
            return ModerationResponse(success=False, text=request.text, message=result["error"])

        # Map the result fields correctly: the text is flagged when the
        # pipeline modified it or when the toxicity score exceeds 0.5.
        flagged = result.get("is_modified", False) or (result.get("toxic_score", 0.0) > 0.5)
        categories = {
            "toxic_score": result.get("toxic_score", 0.0),
            "is_modified": result.get("is_modified", False),
            # The response key is "restored_text" but it is read from the
            # pipeline's "moderated_text" field.
            "restored_text": result.get("moderated_text", request.text),
            "words_replaced": result.get("words_replaced", 0)
        }

        return ModerationResponse(
            success=True,
            text=result["original_text"],
            flagged=flagged,
            categories=categories
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/textgen", response_model=TextGenResponse)
async def generate_text(request: TextGenRequest):
    """Generate text from a prompt with configurable parameters.

    Args:
        request: Payload carrying the prompt ``text``, the generation settings
            (``system_prompt``, ``max_new_tokens``, ``num_return_sequences``,
            ``temperature``, ``do_sample``) and an optional ``model_name``.
            When ``model_name`` is set, a new ``TextGenerator`` is constructed
            for this request; on a CUDA out-of-memory error the preloaded
            pipeline is used instead.

    Returns:
        TextGenResponse: ``success=True`` with the first generation's
        continuation, the parameters and all generations, or ``success=False``
        with a message when the custom model or the pipeline reports an error.

    Raises:
        HTTPException: 503 when the text generation pipeline is not loaded,
            500 for any other unexpected failure.
    """
    try:
        if "textgen" not in pipelines:
            raise HTTPException(status_code=503, detail="Text generation pipeline not available")

        logging.info(f"Generating text for prompt: {request.text[:50]}...")

        # Extract generation parameters
        gen_params = {
            "system_prompt": request.system_prompt,
            "max_new_tokens": request.max_new_tokens,
            "num_return_sequences": request.num_return_sequences,
            "temperature": request.temperature,
            "do_sample": request.do_sample
        }

        if request.model_name:
            try:
                from src.pipelines.textgen import TextGenerator
                textgen = TextGenerator(request.model_name)
                result = textgen.generate(request.text, **gen_params)
            except torch.cuda.OutOfMemoryError:
                # Fall back to the preloaded pipeline when the custom model
                # does not fit in GPU memory.
                logging.warning(f"CUDA OOM with model {request.model_name}, trying with default model on CPU")
                result = pipelines["textgen"].generate(request.text, **gen_params)
            except Exception as model_error:
                logging.error(f"Error with custom model {request.model_name}: {model_error}")
                return TextGenResponse(
                    success=False,
                    prompt=request.text,
                    message=f"Erreur avec le modèle {request.model_name}: {str(model_error)}. Essayez avec le modèle par défaut."
                )
        else:
            result = pipelines["textgen"].generate(request.text, **gen_params)

        logging.info(f"Generation result keys: {list(result.keys())}")

        if "error" in result:
            logging.error(f"Generation error: {result['error']}")
            return TextGenResponse(success=False, prompt=request.text, message=result["error"])

        # Extract the generated text from the first generation
        generated_text = ""
        if "generations" in result and len(result["generations"]) > 0:
            # Get the continuation (text after the prompt) from the first generation
            generated_text = result["generations"][0].get("continuation", "")
            logging.info(f"Extracted generated text: {generated_text[:100]}...")
        else:
            logging.warning("No generations found in result")

        return TextGenResponse(
            success=True,
            prompt=result["prompt"],
            generated_text=generated_text,
            parameters=result.get("parameters", {}),
            generations=result.get("generations", [])
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        logging.error(f"TextGen endpoint error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
# Batch processing endpoints
|
|
@app.post("/sentiment/batch", response_model=BatchResponse)
async def analyze_sentiment_batch(request: TextListRequest):
    """Analyze sentiment for multiple texts.

    Args:
        request: Payload carrying ``texts`` and an optional ``model_name``.
            When ``model_name`` is set, one new ``SentimentAnalyzer`` is
            constructed and used for the whole batch.

    Returns:
        BatchResponse: one result per input text, in input order. A text whose
        analysis raises or reports an error is counted in ``failed_count`` and
        does not abort the batch.

    Raises:
        HTTPException: 503 when the sentiment pipeline is not loaded, 500 for
            any other unexpected failure.
    """
    try:
        if "sentiment" not in pipelines:
            raise HTTPException(status_code=503, detail="Sentiment pipeline not available")

        analyzer = pipelines["sentiment"]
        if request.model_name:
            from src.pipelines.sentiment import SentimentAnalyzer
            analyzer = SentimentAnalyzer(request.model_name)

        results = []
        failed_count = 0

        for text in request.texts:
            try:
                result = analyzer.analyze(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as item_error:
                # Record the failure for this text and keep processing.
                failed_count += 1
                results.append({"text": text, "error": str(item_error)})

        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/ner/batch", response_model=BatchResponse)
async def extract_entities_batch(request: TextListRequest):
    """Extract entities from multiple texts.

    Args:
        request: Payload carrying ``texts`` and an optional ``model_name``.
            When ``model_name`` is set, one new ``NamedEntityRecognizer`` is
            constructed and used for the whole batch.

    Returns:
        BatchResponse: one result per input text, in input order. A text whose
        recognition raises or reports an error is counted in ``failed_count``
        and does not abort the batch.

    Raises:
        HTTPException: 503 when the NER pipeline is not loaded, 500 for any
            other unexpected failure.
    """
    try:
        if "ner" not in pipelines:
            raise HTTPException(status_code=503, detail="NER pipeline not available")

        ner = pipelines["ner"]
        if request.model_name:
            from src.pipelines.ner import NamedEntityRecognizer
            ner = NamedEntityRecognizer(request.model_name)

        results = []
        failed_count = 0

        for text in request.texts:
            try:
                result = ner.recognize(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as item_error:
                # Record the failure for this text and keep processing.
                failed_count += 1
                results.append({"text": text, "error": str(item_error)})

        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/fillmask/batch", response_model=BatchResponse)
async def fill_mask_batch(request: TextListRequest):
    """Fill masked words in multiple texts.

    Args:
        request: Payload carrying ``texts`` and an optional ``model_name``.
            When ``model_name`` is set, one new ``FillMaskAnalyzer`` is
            constructed and used for the whole batch.

    Returns:
        BatchResponse: one result per input text, in input order. A text whose
        prediction raises or reports an error is counted in ``failed_count``
        and does not abort the batch.

    Raises:
        HTTPException: 503 when the fill-mask pipeline is not loaded, 500 for
            any other unexpected failure.
    """
    try:
        if "fillmask" not in pipelines:
            raise HTTPException(status_code=503, detail="Fill-mask pipeline not available")

        fillmask = pipelines["fillmask"]
        if request.model_name:
            from src.pipelines.fillmask import FillMaskAnalyzer
            fillmask = FillMaskAnalyzer(request.model_name)

        results = []
        failed_count = 0

        for text in request.texts:
            try:
                result = fillmask.predict(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as item_error:
                # Record the failure for this text and keep processing.
                failed_count += 1
                results.append({"text": text, "error": str(item_error)})

        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/moderation/batch", response_model=BatchResponse)
async def moderate_content_batch(request: TextListRequest):
    """Moderate multiple texts for inappropriate content.

    Args:
        request: Payload carrying ``texts`` and an optional ``model_name``.
            When ``model_name`` is set, one new ``ContentModerator`` is
            constructed and used for the whole batch.

    Returns:
        BatchResponse: one raw pipeline result per input text, in input order.
        A text whose moderation raises or reports an error is counted in
        ``failed_count`` and does not abort the batch.

    Raises:
        HTTPException: 503 when the moderation pipeline is not loaded, 500 for
            any other unexpected failure.
    """
    try:
        if "moderation" not in pipelines:
            raise HTTPException(status_code=503, detail="Moderation pipeline not available")

        moderation = pipelines["moderation"]
        if request.model_name:
            from src.pipelines.moderation import ContentModerator
            moderation = ContentModerator(request.model_name)

        results = []
        failed_count = 0

        for text in request.texts:
            try:
                result = moderation.moderate(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as item_error:
                # Record the failure for this text and keep processing.
                failed_count += 1
                results.append({"text": text, "error": str(item_error)})

        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/textgen/batch", response_model=BatchResponse)
async def generate_text_batch(request: TextListRequest):
    """Generate text from multiple prompts.

    Args:
        request: Payload carrying ``texts`` (the prompts) and an optional
            ``model_name``. When ``model_name`` is set, one new
            ``TextGenerator`` is constructed and used for the whole batch.
            Each prompt is generated with the generator's default parameters.

    Returns:
        BatchResponse: one raw pipeline result per prompt, in input order. A
        prompt whose generation raises or reports an error is counted in
        ``failed_count`` and does not abort the batch.

    Raises:
        HTTPException: 503 when the text generation pipeline is not loaded,
            500 for any other unexpected failure.
    """
    try:
        if "textgen" not in pipelines:
            raise HTTPException(status_code=503, detail="Text generation pipeline not available")

        textgen = pipelines["textgen"]
        if request.model_name:
            from src.pipelines.textgen import TextGenerator
            textgen = TextGenerator(request.model_name)

        results = []
        failed_count = 0

        for text in request.texts:
            try:
                result = textgen.generate(text)
                if "error" in result:
                    failed_count += 1
                results.append(result)
            except Exception as item_error:
                # Record the failure for this prompt and keep processing.
                failed_count += 1
                results.append({"text": text, "error": str(item_error)})

        return BatchResponse(
            success=True,
            results=results,
            processed_count=len(request.texts),
            failed_count=failed_count
        )
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) through unchanged;
        # the generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e