from typing import Dict, List, Optional

from transformers import pipeline

from src.config import Config


class FillMaskAnalyzer:
    """Fill-mask analyzer using transformers"""

    # Placeholder accepted in user input for every model; it is translated to the
    # loaded tokenizer's own mask token before inference.
    GENERIC_MASK = "[MASK]"

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the fill-mask pipeline

        Args:
            model_name: Name of the model to use (optional). When omitted, the
                model configured under "fillmask" in Config is used.
        """
        self.model_name = model_name or Config.get_model("fillmask")
        print(f"Loading fill-mask model: {self.model_name}")
        self.pipeline = pipeline("fill-mask", model=self.model_name)
        print("Model loaded successfully!")

    def _mask_token(self) -> str:
        """Return the loaded tokenizer's mask token, falling back to ``[MASK]``."""
        tokenizer = getattr(self.pipeline, "tokenizer", None)
        return getattr(tokenizer, "mask_token", None) or self.GENERIC_MASK

    @staticmethod
    def _format_predictions(mask_results: List[Dict]) -> List[Dict]:
        """Convert the raw pipeline candidates for one mask into the public format."""
        return [
            {
                "token": pred["token_str"],
                "score": round(float(pred["score"]), 4),
                "sequence": pred["sequence"],
            }
            for pred in mask_results
        ]

    @classmethod
    def _build_response(cls, text: str, results) -> Dict:
        """Shape raw pipeline output into the response dictionary."""
        # The pipeline returns a flat list of candidates for a single mask, and a
        # list of such lists (one per mask, in order) when there are several masks.
        if isinstance(results, list) and results and isinstance(results[0], list):
            return {
                "original_text": text,
                "masks_count": len(results),
                "predictions": [
                    {
                        "mask_position": position,
                        "predictions": cls._format_predictions(mask_results),
                    }
                    for position, mask_results in enumerate(results, start=1)
                ],
            }
        return {
            "original_text": text,
            "masks_count": 1,
            "predictions": cls._format_predictions(results),
        }

    def predict(self, text: str, top_k: int = 5) -> Dict:
        """
        Predict masked tokens in text

        The generic ``[MASK]`` placeholder works with any model: it is rewritten
        to the loaded tokenizer's own mask token (e.g. ``<mask>``) before
        inference. The model's native mask token is accepted as well.

        Args:
            text: Text with [MASK] token(s) to predict
            top_k: Number of top predictions to return per mask (must be >= 1)

        Returns:
            On success, a dictionary with ``original_text`` (the input as given),
            ``masks_count`` and ``predictions``. On failure, a dictionary with a
            single ``error`` key; this method does not raise.
        """
        if not text or not text.strip():
            return {"error": "Empty text"}

        mask_token = self._mask_token()
        if mask_token != self.GENERIC_MASK:
            # Models such as RoBERTa do not know "[MASK]"; translate it so callers
            # can use one placeholder regardless of the configured model.
            model_text = text.replace(self.GENERIC_MASK, mask_token)
        else:
            model_text = text

        if mask_token not in model_text:
            return {"error": "Text must contain [MASK] token"}

        # The pipeline yields no candidates for top_k <= 0, which previously
        # surfaced as an opaque "list index out of range" error.
        if top_k < 1:
            return {"error": "top_k must be a positive integer"}

        try:
            results = self.pipeline(model_text, top_k=top_k)
            return self._build_response(text, results)
        except Exception as e:
            # API boundary: report inference/formatting failures as data.
            return {"error": f"Prediction error: {str(e)}"}

    def predict_batch(self, texts: List[str], top_k: int = 5) -> List[Dict]:
        """
        Predict masked tokens for multiple texts

        Each text is processed independently with ``predict``, so one invalid
        text yields an ``error`` entry without affecting the others.

        Args:
            texts: List of texts with [MASK] tokens
            top_k: Number of top predictions to return per mask

        Returns:
            List of prediction results, in the same order as ``texts``
        """
        return [self.predict(text, top_k) for text in texts]