ai-lab-transformers-playground/src/config/settings.py

44 lines
1.2 KiB
Python

"""
Global project configuration
"""
from pathlib import Path
from typing import Dict, Any
import torch
class Config:
    """Global application configuration.

    Every setting is a class-level attribute, so values are read directly
    from the class (e.g. ``Config.MAX_BATCH_SIZE``) without instantiating it.
    """

    # Paths
    # settings.py -> config/ -> src/ -> project root. resolve() first so the
    # root is absolute and does not depend on the current working directory
    # (a relative __file__, e.g. when run as a script from inside config/,
    # would otherwise collapse the root to ".").
    PROJECT_ROOT: Path = Path(__file__).resolve().parent.parent.parent
    SRC_DIR: Path = PROJECT_ROOT / "src"

    # Default models: pipeline name -> model identifier.
    DEFAULT_MODELS: Dict[str, str] = {
        "sentiment": "distilbert-base-uncased-finetuned-sst-2-english",
        "fillmask": "bert-base-uncased",
        "textgen": "gpt2",
        "ner": "dslim/bert-base-NER",
        "moderation": "unitary/toxic-bert",
        "qa": "distilbert-base-cased-distilled-squad",
    }

    # Interface
    CLI_BANNER: str = "🤖 AI Lab - Transformers Experimentation"
    CLI_SEPARATOR: str = "=" * 50

    # Performance
    MAX_BATCH_SIZE: int = 32
    DEFAULT_MAX_LENGTH: int = 512
    # Auto-detect GPU availability; evaluated once, when this module is imported.
    USE_GPU: bool = torch.cuda.is_available()

    @classmethod
    def get_model(cls, pipeline_name: str, default: str = "") -> str:
        """Return the default model configured for a pipeline.

        Args:
            pipeline_name: Key into ``DEFAULT_MODELS`` (e.g. ``"sentiment"``).
            default: Value returned when ``pipeline_name`` has no configured
                model. Defaults to ``""`` to match the historical behavior.

        Returns:
            The configured model identifier, or ``default`` if the pipeline
            name is not present in ``DEFAULT_MODELS``.
        """
        return cls.DEFAULT_MODELS.get(pipeline_name, default)

    @classmethod
    def get_all_models(cls) -> Dict[str, str]:
        """Return a shallow copy of the pipeline -> model mapping.

        Mutating the returned dict does not modify ``DEFAULT_MODELS``.
        """
        return cls.DEFAULT_MODELS.copy()