""" FastAPI application for AI Lab """ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager from typing import Dict, Any import logging import torch from .models import ( TextRequest, TextListRequest, QARequest, FillMaskRequest, TextGenRequest, SentimentResponse, NERResponse, QAResponse, FillMaskResponse, ModerationResponse, TextGenResponse, BatchResponse ) # Global pipeline instances pipelines: Dict[str, Any] = {} @asynccontextmanager async def lifespan(app: FastAPI): """Manage application lifespan - load models on startup""" global pipelines # Load all pipelines on startup try: logging.info("Loading AI pipelines...") # Import here to avoid circular imports from src.pipelines.sentiment import SentimentAnalyzer from src.pipelines.ner import NamedEntityRecognizer from src.pipelines.qa import QuestionAnsweringSystem from src.pipelines.fillmask import FillMaskAnalyzer from src.pipelines.moderation import ContentModerator from src.pipelines.textgen import TextGenerator pipelines["sentiment"] = SentimentAnalyzer() pipelines["ner"] = NamedEntityRecognizer() pipelines["qa"] = QuestionAnsweringSystem() pipelines["fillmask"] = FillMaskAnalyzer() pipelines["moderation"] = ContentModerator() pipelines["textgen"] = TextGenerator() logging.info("All pipelines loaded successfully!") except Exception as e: logging.error(f"Error loading pipelines: {e}") # Don't raise, just log - allows API to start without all pipelines yield # Cleanup on shutdown pipelines.clear() logging.info("Pipelines cleaned up") # Create FastAPI app app = FastAPI( title="AI Lab API", description="API for various AI/ML pipelines using transformers", version="1.0.0", lifespan=lifespan, swagger_ui_parameters={ "syntaxHighlight.theme": "obsidian", "tryItOutEnabled": True, "requestSnippetsEnabled": True, "persistAuthorization": True, "displayRequestDuration": True, "defaultModelRendering": "model" } ) # Add CORS middleware 
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; list explicit origins before
# exposing this service publicly. Behavior is intentionally left unchanged here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _require_pipeline(key: str, label: str) -> Any:
    """Return the preloaded pipeline stored under ``key`` or raise HTTP 503.

    Endpoints call this *before* entering their ``try`` block so that the 503
    is not caught by the generic ``except Exception`` handler and rewritten
    into a 500.

    Args:
        key: Name of the pipeline in the global ``pipelines`` registry.
        label: Human-readable name used in the error detail.

    Raises:
        HTTPException: 503 when no pipeline is registered under ``key``.
    """
    if key not in pipelines:
        raise HTTPException(status_code=503, detail=f"{label} pipeline not available")
    return pipelines[key]


def _run_batch(process: Any, texts: Any) -> BatchResponse:
    """Apply ``process`` to every text and collect the results.

    A text counts as failed when ``process`` raises or when its result
    contains an ``"error"`` key. Raised exceptions are recorded as
    ``{"text": ..., "error": ...}`` entries so one bad item does not abort
    the whole batch.

    Args:
        process: Callable taking one text and returning a result mapping.
        texts: The texts to process, in order.

    Returns:
        A ``BatchResponse`` with one result per input text.
    """
    results = []
    failed_count = 0
    for text in texts:
        try:
            result = process(text)
            if "error" in result:
                failed_count += 1
            results.append(result)
        except Exception as e:
            failed_count += 1
            results.append({"text": text, "error": str(e)})

    return BatchResponse(
        success=True,
        results=results,
        processed_count=len(texts),
        failed_count=failed_count
    )


@app.get("/")
async def root():
    """Root endpoint: return a welcome message, the API version and the endpoint list."""
    return {
        "message": "Welcome to AI Lab API",
        "version": "1.0.0",
        "available_endpoints": [
            "/sentiment",
            "/ner",
            "/qa",
            "/fillmask",
            "/moderation",
            "/textgen",
            "/sentiment/batch",
            "/ner/batch",
            "/fillmask/batch",
            "/moderation/batch",
            "/textgen/batch",
            "/health",
            "/docs"
        ]
    }


@app.get("/health")
async def health_check():
    """Health check endpoint: report how many and which pipelines are loaded."""
    return {
        "status": "healthy",
        "pipelines_loaded": len(pipelines),
        "available_pipelines": list(pipelines.keys())
    }


@app.post("/sentiment", response_model=SentimentResponse)
async def analyze_sentiment(request: TextRequest):
    """Analyze sentiment of a text.

    Uses a freshly created ``SentimentAnalyzer`` when ``request.model_name``
    is set, otherwise the preloaded pipeline.

    Raises:
        HTTPException: 503 if the sentiment pipeline is not loaded,
            500 on any unexpected error.
    """
    analyzer = _require_pipeline("sentiment", "Sentiment")
    try:
        # Use custom model if provided
        if request.model_name:
            from src.pipelines.sentiment import SentimentAnalyzer
            analyzer = SentimentAnalyzer(request.model_name)

        result = analyzer.analyze(request.text)

        if "error" in result:
            return SentimentResponse(success=False, text=request.text, message=result["error"])

        return SentimentResponse(
            success=True,
            text=result["text"],
            sentiment=result["sentiment"],
            confidence=result["confidence"]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/ner", response_model=NERResponse)
async def extract_entities(request: TextRequest):
    """Extract named entities from text.

    Errors raised by the pipeline itself, an ``"error"`` key in the result, or
    a result missing ``"original_text"``/``"entities"`` are reported as a
    ``success=False`` response rather than an HTTP error.

    Raises:
        HTTPException: 503 if the NER pipeline is not loaded,
            500 on any unexpected error.
    """
    logging.info("NER request for text: %s...", request.text[:50])
    ner = _require_pipeline("ner", "NER")
    try:
        try:
            if request.model_name:
                from src.pipelines.ner import NamedEntityRecognizer
                ner = NamedEntityRecognizer(request.model_name)
            result = ner.recognize(request.text)
        except Exception as pipeline_error:
            logging.error("Pipeline error: %s", pipeline_error)
            return NERResponse(
                success=False,
                text=request.text,
                message=f"Pipeline error: {str(pipeline_error)}"
            )

        logging.info("NER result keys: %s", list(result.keys()))

        if "error" in result:
            logging.error("NER error: %s", result["error"])
            return NERResponse(success=False, text=request.text, message=result["error"])

        # Validate result structure
        for required_key in ("original_text", "entities"):
            if required_key not in result:
                logging.error("Missing '%s' in result: %s", required_key, result)
                return NERResponse(
                    success=False,
                    text=request.text,
                    message="Invalid NER result format"
                )

        return NERResponse(
            success=True,
            text=result["original_text"],
            entities=result["entities"]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/qa", response_model=QAResponse)
async def answer_question(request: QARequest):
    """Answer a question based on context.

    Raises:
        HTTPException: 503 if the QA pipeline is not loaded,
            500 on any unexpected error.
    """
    qa = _require_pipeline("qa", "QA")
    try:
        if request.model_name:
            from src.pipelines.qa import QuestionAnsweringSystem
            qa = QuestionAnsweringSystem(request.model_name)

        result = qa.answer(request.question, request.context)

        if "error" in result:
            return QAResponse(
                success=False,
                question=request.question,
                context=request.context,
                message=result["error"]
            )

        return QAResponse(
            success=True,
            question=result["question"],
            context=result["context"],
            answer=result["answer"],
            confidence=result["confidence"]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/fillmask", response_model=FillMaskResponse)
async def fill_mask(request: FillMaskRequest):
    """Fill masked words in text.

    Raises:
        HTTPException: 503 if the fill-mask pipeline is not loaded,
            500 on any unexpected error.
    """
    fillmask = _require_pipeline("fillmask", "Fill-mask")
    try:
        if request.model_name:
            from src.pipelines.fillmask import FillMaskAnalyzer
            fillmask = FillMaskAnalyzer(request.model_name)

        result = fillmask.predict(request.text)

        if "error" in result:
            return FillMaskResponse(success=False, text=request.text, message=result["error"])

        return FillMaskResponse(
            success=True,
            text=result["original_text"],
            predictions=result["predictions"]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/moderation", response_model=ModerationResponse)
async def moderate_content(request: TextRequest):
    """Moderate content for inappropriate material.

    The text is flagged when the pipeline reports ``is_modified`` or a
    ``toxic_score`` above 0.5.

    Raises:
        HTTPException: 503 if the moderation pipeline is not loaded,
            500 on any unexpected error.
    """
    moderation = _require_pipeline("moderation", "Moderation")
    try:
        if request.model_name:
            from src.pipelines.moderation import ContentModerator
            moderation = ContentModerator(request.model_name)

        result = moderation.moderate(request.text)

        if "error" in result:
            return ModerationResponse(success=False, text=request.text, message=result["error"])

        # Map the result fields correctly
        flagged = result.get("is_modified", False) or result.get("toxic_score", 0.0) > 0.5
        categories = {
            "toxic_score": result.get("toxic_score", 0.0),
            "is_modified": result.get("is_modified", False),
            "restored_text": result.get("moderated_text", request.text),
            "words_replaced": result.get("words_replaced", 0)
        }

        return ModerationResponse(
            success=True,
            text=result["original_text"],
            flagged=flagged,
            categories=categories
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/textgen", response_model=TextGenResponse)
async def generate_text(request: TextGenRequest):
    """Generate text from a prompt with configurable parameters.

    With a custom ``request.model_name``, a CUDA out-of-memory error falls
    back to the preloaded default pipeline; any other error from the custom
    model is returned as a ``success=False`` response.

    Raises:
        HTTPException: 503 if the text generation pipeline is not loaded,
            500 on any unexpected error.
    """
    default_textgen = _require_pipeline("textgen", "Text generation")
    try:
        logging.info("Generating text for prompt: %s...", request.text[:50])

        # Extract generation parameters
        gen_params = {
            "system_prompt": request.system_prompt,
            "max_new_tokens": request.max_new_tokens,
            "num_return_sequences": request.num_return_sequences,
            "temperature": request.temperature,
            "do_sample": request.do_sample
        }

        if request.model_name:
            try:
                from src.pipelines.textgen import TextGenerator
                textgen = TextGenerator(request.model_name)
                result = textgen.generate(request.text, **gen_params)
            except torch.cuda.OutOfMemoryError:
                logging.warning(
                    f"CUDA OOM with model {request.model_name}, trying with default model on CPU"
                )
                result = default_textgen.generate(request.text, **gen_params)
            except Exception as model_error:
                logging.error(f"Error with custom model {request.model_name}: {model_error}")
                return TextGenResponse(
                    success=False,
                    prompt=request.text,
                    message=f"Erreur avec le modèle {request.model_name}: {str(model_error)}. Essayez avec le modèle par défaut."
                )
        else:
            result = default_textgen.generate(request.text, **gen_params)

        logging.info("Generation result keys: %s", list(result.keys()))

        if "error" in result:
            logging.error("Generation error: %s", result["error"])
            return TextGenResponse(success=False, prompt=request.text, message=result["error"])

        # Extract the generated text from the first generation
        generated_text = ""
        if "generations" in result and len(result["generations"]) > 0:
            # Get the continuation (text after the prompt) from the first generation
            generated_text = result["generations"][0].get("continuation", "")
            logging.info("Extracted generated text: %s...", generated_text[:100])
        else:
            logging.warning("No generations found in result")

        return TextGenResponse(
            success=True,
            prompt=result["prompt"],
            generated_text=generated_text,
            parameters=result.get("parameters", {}),
            generations=result.get("generations", [])
        )
    except Exception as e:
        logging.error(f"TextGen endpoint error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


# Batch processing endpoints

@app.post("/sentiment/batch", response_model=BatchResponse)
async def analyze_sentiment_batch(request: TextListRequest):
    """Analyze sentiment for multiple texts.

    Raises:
        HTTPException: 503 if the sentiment pipeline is not loaded,
            500 on any unexpected error.
    """
    analyzer = _require_pipeline("sentiment", "Sentiment")
    try:
        if request.model_name:
            from src.pipelines.sentiment import SentimentAnalyzer
            analyzer = SentimentAnalyzer(request.model_name)
        return _run_batch(analyzer.analyze, request.texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/ner/batch", response_model=BatchResponse)
async def extract_entities_batch(request: TextListRequest):
    """Extract entities from multiple texts.

    Raises:
        HTTPException: 503 if the NER pipeline is not loaded,
            500 on any unexpected error.
    """
    ner = _require_pipeline("ner", "NER")
    try:
        if request.model_name:
            from src.pipelines.ner import NamedEntityRecognizer
            ner = NamedEntityRecognizer(request.model_name)
        return _run_batch(ner.recognize, request.texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/fillmask/batch", response_model=BatchResponse)
async def fill_mask_batch(request: TextListRequest):
    """Fill masked words in multiple texts.

    Raises:
        HTTPException: 503 if the fill-mask pipeline is not loaded,
            500 on any unexpected error.
    """
    fillmask = _require_pipeline("fillmask", "Fill-mask")
    try:
        if request.model_name:
            from src.pipelines.fillmask import FillMaskAnalyzer
            fillmask = FillMaskAnalyzer(request.model_name)
        return _run_batch(fillmask.predict, request.texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/moderation/batch", response_model=BatchResponse)
async def moderate_content_batch(request: TextListRequest):
    """Moderate multiple texts for inappropriate content.

    Raises:
        HTTPException: 503 if the moderation pipeline is not loaded,
            500 on any unexpected error.
    """
    moderation = _require_pipeline("moderation", "Moderation")
    try:
        if request.model_name:
            from src.pipelines.moderation import ContentModerator
            moderation = ContentModerator(request.model_name)
        return _run_batch(moderation.moderate, request.texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/textgen/batch", response_model=BatchResponse)
async def generate_text_batch(request: TextListRequest):
    """Generate text from multiple prompts using the pipeline's default parameters.

    Raises:
        HTTPException: 503 if the text generation pipeline is not loaded,
            500 on any unexpected error.
    """
    textgen = _require_pipeline("textgen", "Text generation")
    try:
        if request.model_name:
            from src.pipelines.textgen import TextGenerator
            textgen = TextGenerator(request.model_name)
        return _run_batch(textgen.generate, request.texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e