ai-lab-transformers-playground/src/pipelines/textgen.py


from typing import Dict, List, Optional

import torch
from transformers import pipeline

from src.config import Config


class TextGenerator:
    """Text generator using transformers."""

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the text-generation pipeline.

        Args:
            model_name: Name of the model to use (optional).
        """
        self.model_name = model_name or Config.get_model("textgen")
        print(f"Loading text generation model: {self.model_name}")

        # Clear GPU cache before loading a new model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Try GPU first, fall back to CPU on CUDA OOM or other load errors
        try:
            # Initialize pipeline with proper device configuration
            self.pipeline = pipeline(
                "text-generation",
                model=self.model_name,
                device=0 if Config.USE_GPU else -1,
                torch_dtype="auto",
            )
            print(f"Model loaded successfully on {'GPU' if Config.USE_GPU else 'CPU'}!")
        except torch.cuda.OutOfMemoryError:
            print("⚠️ GPU out of memory, falling back to CPU...")
            # Force CPU usage
            self.pipeline = pipeline(
                "text-generation",
                model=self.model_name,
                device=-1,  # CPU
                torch_dtype="auto",
            )
            print("Model loaded successfully on CPU!")
        except Exception as e:
            print(f"⚠️ Error loading model on GPU, trying CPU: {e}")
            # Fallback to CPU
            self.pipeline = pipeline(
                "text-generation",
                model=self.model_name,
                device=-1,  # CPU
                torch_dtype="auto",
            )
            print("Model loaded successfully on CPU!")

        # Set pad token if the tokenizer does not define one
        if self.pipeline.tokenizer.pad_token is None:
            self.pipeline.tokenizer.pad_token = self.pipeline.tokenizer.eos_token

    def generate(self, prompt: str, system_prompt: Optional[str] = None, max_new_tokens: int = 500,
                 num_return_sequences: int = 1, temperature: float = 1.0, do_sample: bool = True) -> Dict:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt.
            system_prompt: Optional system prompt to set context/role.
            max_new_tokens: Maximum number of new tokens to generate.
            num_return_sequences: Number of sequences to generate.
            temperature: Sampling temperature (higher = more random).
            do_sample: Whether to use sampling.

        Returns:
            Dictionary with generated texts.
        """
        if not prompt.strip():
            return {"error": "Empty prompt"}

        if system_prompt:
            full_prompt = f"{system_prompt.strip()}\n\n{prompt.strip()}\n\n"
        else:
            full_prompt = f"{prompt.strip()}\n\n"

        try:
            results = self.pipeline(
                full_prompt,
                max_new_tokens=max_new_tokens,
                num_return_sequences=num_return_sequences,
                temperature=temperature,
                do_sample=do_sample,
                pad_token_id=self.pipeline.tokenizer.eos_token_id,
                return_full_text=True,
            )

            generations = [
                {
                    "text": result["generated_text"],
                    "continuation": result["generated_text"][len(full_prompt):].strip(),
                }
                for result in results
            ]

            return {
                "prompt": prompt,
                "system_prompt": system_prompt,
                "full_prompt": full_prompt,
                "parameters": {
                    "max_new_tokens": max_new_tokens,
                    "num_sequences": num_return_sequences,
                    "temperature": temperature,
                    "do_sample": do_sample,
                },
                "generations": generations,
            }
        except Exception as e:
            return {"error": f"Generation error: {str(e)}"}

    def generate_batch(self, prompts: List[str], **kwargs) -> List[Dict]:
        """
        Generate text for multiple prompts.

        Args:
            prompts: List of input prompts.
            **kwargs: Generation parameters.

        Returns:
            List of generation results.
        """
        return [self.generate(prompt, **kwargs) for prompt in prompts]
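

# Minimal usage sketch, not part of the original module: it assumes the script is
# run from the project root so that `src.config.Config` resolves and provides
# USE_GPU and get_model("textgen"). Prompts and parameter values are illustrative.
if __name__ == "__main__":
    generator = TextGenerator()

    # Single prompt with an optional system prompt
    result = generator.generate(
        "Write a haiku about gradient descent.",
        system_prompt="You are a concise poetry assistant.",
        max_new_tokens=60,
        temperature=0.8,
    )
    if "error" in result:
        print(result["error"])
    else:
        for generation in result["generations"]:
            print(generation["continuation"])

    # Batch generation reuses the same keyword arguments for every prompt
    batch_results = generator.generate_batch(
        ["Explain attention in one sentence.", "Name three uses of embeddings."],
        max_new_tokens=40,
    )
    print(f"Generated {len(batch_results)} results")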