from transformers import pipeline
from typing import Dict, List, Optional
import torch
from src.config import Config


class TextGenerator:
    """Text generator using transformers."""

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the text-generation pipeline.

        Args:
            model_name: Name of the model to use (optional)
        """
        self.model_name = model_name or Config.get_model("textgen")
        print(f"Loading text generation model: {self.model_name}")

        # Clear the GPU cache before loading a new model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Try the GPU first, falling back to the CPU on CUDA out-of-memory
        try:
            # Initialize the pipeline with the configured device
            self.pipeline = pipeline(
                "text-generation",
                model=self.model_name,
                device=0 if Config.USE_GPU else -1,
                torch_dtype="auto",
            )
            print(f"Model loaded successfully on {'GPU' if Config.USE_GPU else 'CPU'}!")

        except torch.cuda.OutOfMemoryError:
            print("⚠️ GPU out of memory, falling back to CPU...")
            # Force CPU usage
            self.pipeline = pipeline(
                "text-generation",
                model=self.model_name,
                device=-1,  # CPU
                torch_dtype="auto",
            )
            print("Model loaded successfully on CPU!")

        except Exception as e:
            print(f"⚠️ Error loading model on GPU, trying CPU: {e}")
            # Fall back to the CPU for any other loading error
            self.pipeline = pipeline(
                "text-generation",
                model=self.model_name,
                device=-1,  # CPU
                torch_dtype="auto",
            )
            print("Model loaded successfully on CPU!")

        # Set a pad token if the tokenizer does not define one
        if self.pipeline.tokenizer.pad_token is None:
            self.pipeline.tokenizer.pad_token = self.pipeline.tokenizer.eos_token

    def generate(self, prompt: str, system_prompt: Optional[str] = None, max_new_tokens: int = 500,
                 num_return_sequences: int = 1, temperature: float = 1.0, do_sample: bool = True) -> Dict:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt
            system_prompt: Optional system prompt to set context/role
            max_new_tokens: Maximum number of new tokens to generate
            num_return_sequences: Number of sequences to generate
            temperature: Sampling temperature (higher = more random)
            do_sample: Whether to use sampling

        Returns:
            Dictionary with the generated texts, or an "error" key on failure
        """
        if not prompt.strip():
            return {"error": "Empty prompt"}

        # Prepend the optional system prompt to build the full prompt
        if system_prompt:
            full_prompt = f"{system_prompt.strip()}\n\n{prompt.strip()}\n\n"
        else:
            full_prompt = f"{prompt.strip()}\n\n"

        try:
            results = self.pipeline(
                full_prompt,
                max_new_tokens=max_new_tokens,
                num_return_sequences=num_return_sequences,
                temperature=temperature,
                do_sample=do_sample,
                pad_token_id=self.pipeline.tokenizer.eos_token_id,
                return_full_text=True,
            )

            # Keep the full text and also expose just the newly generated continuation
            generations = [
                {
                    "text": result["generated_text"],
                    "continuation": result["generated_text"][len(full_prompt):].strip()
                }
                for result in results
            ]

            return {
                "prompt": prompt,
                "system_prompt": system_prompt,
                "full_prompt": full_prompt,
                "parameters": {
                    "max_new_tokens": max_new_tokens,
                    "num_sequences": num_return_sequences,
                    "temperature": temperature,
                    "do_sample": do_sample
                },
                "generations": generations
            }

        except Exception as e:
            return {"error": f"Generation error: {str(e)}"}

    def generate_batch(self, prompts: List[str], **kwargs) -> List[Dict]:
        """
        Generate text for multiple prompts.

        Args:
            prompts: List of input prompts
            **kwargs: Generation parameters

        Returns:
            List of generation results
        """
        return [self.generate(prompt, **kwargs) for prompt in prompts]
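

# Example usage (a minimal sketch, not part of the original module): exercises
# generate() and generate_batch() as defined above. The prompts and parameter
# values are illustrative assumptions; model selection and device placement
# come from src.config.Config as used in __init__.
if __name__ == "__main__":
    generator = TextGenerator()

    # Single prompt with an optional system prompt and sampling controls
    result = generator.generate(
        "Write a haiku about the ocean.",
        system_prompt="You are a concise poet.",
        max_new_tokens=60,
        temperature=0.8,
    )
    if "error" in result:
        print(result["error"])
    else:
        for generation in result["generations"]:
            print(generation["continuation"])

    # Batch generation forwards the same keyword arguments to every prompt
    batch_results = generator.generate_batch(
        ["Summarize photosynthesis in one sentence.", "List three uses of graphite."],
        max_new_tokens=40,
    )
    for item in batch_results:
        print(item.get("generations", item.get("error")))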