from typing import Dict, List, Optional

import torch
from transformers import pipeline

from src.config import Config


class TextGenerator:
    """Text generator built on the Hugging Face transformers pipeline."""

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the text-generation pipeline.

        Args:
            model_name: Name of the model to use (optional). Defaults to the
                model configured under the "textgen" key in Config.
        """
        self.model_name = model_name or Config.get_model("textgen")
        print(f"Loading text generation model: {self.model_name}")

        # Clear the GPU cache before loading a new model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Try the GPU first; fall back to the CPU on CUDA OOM or any other
        # load-time error.
        try:
            self.pipeline = pipeline(
                "text-generation",
                model=self.model_name,
                device=0 if Config.USE_GPU else -1,
                torch_dtype="auto",
            )
            print(f"Model loaded successfully on {'GPU' if Config.USE_GPU else 'CPU'}!")
        except torch.cuda.OutOfMemoryError:
            print("⚠️ GPU out of memory, falling back to CPU...")
            self.pipeline = self._load_on_cpu()
        except Exception as e:
            print(f"⚠️ Error loading model on GPU, trying CPU: {e}")
            self.pipeline = self._load_on_cpu()

        # Ensure a pad token is set (many causal LMs ship without one)
        if self.pipeline.tokenizer.pad_token is None:
            self.pipeline.tokenizer.pad_token = self.pipeline.tokenizer.eos_token

    def _load_on_cpu(self):
        """Load the pipeline on the CPU (fallback path)."""
        pipe = pipeline(
            "text-generation",
            model=self.model_name,
            device=-1,  # CPU
            torch_dtype="auto",
        )
        print("Model loaded successfully on CPU!")
        return pipe

    def generate(self,
                 prompt: str,
                 system_prompt: Optional[str] = None,
                 max_new_tokens: int = 500,
                 num_return_sequences: int = 1,
                 temperature: float = 1.0,
                 do_sample: bool = True) -> Dict:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt
            system_prompt: Optional system prompt to set context/role
            max_new_tokens: Maximum number of new tokens to generate
            num_return_sequences: Number of sequences to generate
            temperature: Sampling temperature (higher = more random)
            do_sample: Whether to use sampling

        Returns:
            Dictionary with the generated texts, or an "error" key on failure
        """
        if not prompt.strip():
            return {"error": "Empty prompt"}

        if system_prompt:
            full_prompt = f"{system_prompt.strip()}\n\n{prompt.strip()}\n\n"
        else:
            full_prompt = f"{prompt.strip()}\n\n"

        try:
            results = self.pipeline(
                full_prompt,
                max_new_tokens=max_new_tokens,
                num_return_sequences=num_return_sequences,
                temperature=temperature,
                do_sample=do_sample,
                pad_token_id=self.pipeline.tokenizer.eos_token_id,
                return_full_text=True,
            )

            generations = [
                {
                    "text": result["generated_text"],
                    # Strip the prompt off the front to isolate the new text
                    "continuation": result["generated_text"][len(full_prompt):].strip(),
                }
                for result in results
            ]

            return {
                "prompt": prompt,
                "system_prompt": system_prompt,
                "full_prompt": full_prompt,
                "parameters": {
                    "max_new_tokens": max_new_tokens,
                    "num_sequences": num_return_sequences,
                    "temperature": temperature,
                    "do_sample": do_sample,
                },
                "generations": generations,
            }
        except Exception as e:
            return {"error": f"Generation error: {str(e)}"}

    def generate_batch(self, prompts: List[str], **kwargs) -> List[Dict]:
        """
        Generate text for multiple prompts.

        Args:
            prompts: List of input prompts
            **kwargs: Generation parameters forwarded to generate()

        Returns:
            List of generation results, one per prompt
        """
        return [self.generate(prompt, **kwargs) for prompt in prompts]
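

# A minimal usage sketch, not part of the original module: it exercises
# generate() and generate_batch() end to end. The prompts, system prompt,
# and parameter values below are illustrative assumptions; the model itself
# comes from Config.get_model("textgen") as in __init__ above.
if __name__ == "__main__":
    generator = TextGenerator()

    # Single prompt with a system prompt and light sampling
    result = generator.generate(
        "Write a haiku about autumn.",
        system_prompt="You are a concise poet.",
        max_new_tokens=50,
        temperature=0.8,
    )
    if "error" in result:
        print(result["error"])
    else:
        for i, gen in enumerate(result["generations"], start=1):
            print(f"--- generation {i} ---")
            print(gen["continuation"])

    # Batch generation: one result dict per prompt
    batch = generator.generate_batch(
        ["Summarize the water cycle.", "Explain recursion in one sentence."],
        max_new_tokens=60,
    )
    for item in batch:
        print(item.get("generations", item.get("error")))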