61 lines
1.7 KiB
Python
61 lines
1.7 KiB
Python
"""
|
|
API configuration settings
|
|
"""
|
|
from typing import Dict, Any
|
|
|
|
|
|
class APIConfig:
    """Configuration for the FastAPI application.

    All settings are class-level constants. The classmethods below expose
    them without needing an instance.
    """

    # Server settings
    DEFAULT_HOST = "127.0.0.1"
    DEFAULT_PORT = 8000

    # API settings
    API_TITLE = "AI Lab API"
    API_DESCRIPTION = "API for various AI/ML pipelines using transformers"
    API_VERSION = "1.0.0"

    # CORS settings
    # NOTE(review): "*" for origins, methods and headers is fully permissive;
    # tighten these before a production deployment.
    CORS_ORIGINS = ["*"]  # Configure for production
    CORS_METHODS = ["*"]
    CORS_HEADERS = ["*"]

    # Pipeline settings
    # NOTE(review): units are presumably characters per text and items per
    # batch request -- confirm against the code that enforces these limits.
    MAX_TEXT_LENGTH = 10000
    MAX_BATCH_SIZE = 100

    # Model defaults: task name -> model identifier
    # (presumably Hugging Face Hub model ids -- verify against the loaders).
    DEFAULT_MODELS: Dict[str, str] = {
        "sentiment": "cardiffnlp/twitter-roberta-base-sentiment-latest",
        "ner": "dbmdz/bert-large-cased-finetuned-conll03-english",
        "qa": "distilbert-base-cased-distilled-squad",
        "fillmask": "bert-base-uncased",
        "moderation": "martin-ha/toxic-comment-model",
        "textgen": "gpt2",
    }

    @classmethod
    def get_default_model(cls, task: str, default: str = "") -> str:
        """Return the default model identifier for *task*.

        Args:
            task: Task key such as ``"sentiment"``, ``"ner"`` or ``"qa"``.
            default: Value returned when *task* has no configured model.
                Defaults to ``""``, which matches the original behavior.

        Returns:
            The configured model identifier, or *default* for an unknown task.
        """
        return cls.DEFAULT_MODELS.get(task, default)

    @classmethod
    def get_all_settings(cls) -> Dict[str, Any]:
        """Return a snapshot of the configuration, grouped by section.

        The snapshot has the sections ``server``, ``api``, ``limits`` and
        ``default_models``. CORS settings are not included.

        Returns:
            A freshly built nested dict. ``default_models`` is a shallow copy
            of ``DEFAULT_MODELS``, so mutating the result (for example, before
            serializing it in a response) cannot alter the class-level
            configuration.
        """
        return {
            "server": {
                "default_host": cls.DEFAULT_HOST,
                "default_port": cls.DEFAULT_PORT,
            },
            "api": {
                "title": cls.API_TITLE,
                "description": cls.API_DESCRIPTION,
                "version": cls.API_VERSION,
            },
            "limits": {
                "max_text_length": cls.MAX_TEXT_LENGTH,
                "max_batch_size": cls.MAX_BATCH_SIZE,
            },
            # Copy so callers never hold a reference to the shared class dict.
            "default_models": dict(cls.DEFAULT_MODELS),
        }