"""
|
|
Global project configuration
|
|
"""
|
|
from pathlib import Path
|
|
from typing import Dict, Any
|
|
import torch
|
|
|
|
|
|
class Config:
|
|
"""Global application configuration"""
|
|
|
|
# Paths
|
|
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
|
SRC_DIR = PROJECT_ROOT / "src"
|
|
|
|
# Default models
|
|
DEFAULT_MODELS = {
|
|
"sentiment": "distilbert-base-uncased-finetuned-sst-2-english",
|
|
"fillmask": "bert-base-uncased",
|
|
"textgen": "gpt2",
|
|
"ner": "dslim/bert-base-NER",
|
|
"moderation":"unitary/toxic-bert",
|
|
"qa": "distilbert-base-cased-distilled-squad",
|
|
}
|
|
|
|
# Interface
|
|
CLI_BANNER = "🤖 AI Lab - Transformers Experimentation"
|
|
CLI_SEPARATOR = "=" * 50
|
|
|
|
# Performance
|
|
MAX_BATCH_SIZE = 32
|
|
DEFAULT_MAX_LENGTH = 512
|
|
USE_GPU = torch.cuda.is_available() # Auto-detect GPU availability
|
|
|
|
@classmethod
|
|
def get_model(cls, pipeline_name: str) -> str:
|
|
"""Get default model for a pipeline"""
|
|
return cls.DEFAULT_MODELS.get(pipeline_name, "")
|
|
|
|
@classmethod
|
|
def get_all_models(cls) -> Dict[str, str]:
|
|
"""Get all configured models"""
|
|
return cls.DEFAULT_MODELS.copy()
|
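

# Minimal usage sketch: how Config could feed a Hugging Face pipeline.
# The `transformers` import and the "sentiment" task name are assumptions
# for illustration only; they are not requirements of this module.
if __name__ == "__main__":
    model_name = Config.get_model("sentiment")
    device = 0 if Config.USE_GPU else -1  # transformers convention: -1 means CPU

    print(Config.CLI_BANNER)
    print(Config.CLI_SEPARATOR)
    print(f"Model for 'sentiment': {model_name}")
    print(f"All configured models: {Config.get_all_models()}")

    # Uncomment to actually load the model (requires the `transformers` package):
    # from transformers import pipeline
    # classifier = pipeline("sentiment-analysis", model=model_name, device=device)
    # print(classifier("Transformers make NLP experimentation easy."))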