# ai-lab-transformers-playground/src/pipelines/fillmask.py
#
# 96 lines
# 3.2 KiB
# Python

from transformers import pipeline
from typing import Dict, List, Optional
from src.config import Config
class FillMaskAnalyzer:
    """Fill-mask analyzer using transformers.

    Wraps a ``transformers`` ``fill-mask`` pipeline and converts its raw
    output into plain dictionaries. Validation problems and pipeline
    failures are reported as ``{"error": ...}`` dictionaries rather than
    raised, so one bad input never aborts a batch.
    """

    # Mask marker that always passes this class's input check; also the
    # fallback when the loaded tokenizer does not expose its own mask token.
    _DEFAULT_MASK_TOKEN = "[MASK]"

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the fill-mask pipeline.

        Args:
            model_name: Name of the model to use (optional). When omitted or
                empty, the name is taken from ``Config.get_model("fillmask")``.
        """
        self.model_name = model_name or Config.get_model("fillmask")
        print(f"Loading fill-mask model: {self.model_name}")
        self.pipeline = pipeline("fill-mask", model=self.model_name)
        print("Model loaded successfully!")

    def _model_mask_token(self) -> str:
        """Return the loaded tokenizer's mask token, or ``[MASK]`` if unknown."""
        tokenizer = getattr(self.pipeline, "tokenizer", None)
        token = getattr(tokenizer, "mask_token", None)
        if isinstance(token, str) and token:
            return token
        return self._DEFAULT_MASK_TOKEN

    @staticmethod
    def _format_predictions(raw_predictions: List[Dict]) -> List[Dict]:
        """Convert raw pipeline candidates for one mask into result entries."""
        return [
            {
                "token": pred["token_str"],
                "score": round(float(pred["score"]), 4),
                "sequence": pred["sequence"],
            }
            for pred in raw_predictions
        ]

    def predict(self, text: str, top_k: int = 5) -> Dict:
        """
        Predict masked tokens in text.

        Args:
            text: Text containing one or more mask tokens. ``[MASK]`` is
                always accepted by the input check; the loaded model's own
                mask token (when the tokenizer exposes one) is accepted too.
            top_k: Number of top predictions to return per mask.

        Returns:
            On success, a dictionary with ``original_text``, ``masks_count``
            and ``predictions``. For a single mask, ``predictions`` is a list
            of ``{"token", "score", "sequence"}`` entries with scores rounded
            to 4 decimals. For several masks, it is a list of
            ``{"mask_position", "predictions"}`` entries, one per mask, with
            positions starting at 1.
            On failure, a dictionary with a single ``error`` message.
        """
        # Guard before calling .strip(): a non-string item must produce an
        # error entry instead of raising and aborting a whole batch.
        if not isinstance(text, str):
            return {"error": "Text must be a string"}
        if not text.strip():
            return {"error": "Empty text"}

        mask_token = self._model_mask_token()
        if self._DEFAULT_MASK_TOKEN not in text and mask_token not in text:
            return {"error": f"Text must contain {mask_token} token"}

        # Only integer values below 1 are rejected here; anything else is
        # passed through to the pipeline unchanged.
        if isinstance(top_k, int) and top_k < 1:
            return {"error": "top_k must be a positive integer"}

        # Deliberately broad: any pipeline or formatting failure is turned
        # into an error dictionary, which is this method's error contract.
        try:
            results = self.pipeline(text, top_k=top_k)

            # A list of lists means one candidate list per mask. Checking
            # for emptiness first avoids an IndexError on an empty result.
            if isinstance(results, list) and results and isinstance(results[0], list):
                predictions = [
                    {
                        "mask_position": position,
                        "predictions": self._format_predictions(mask_results),
                    }
                    for position, mask_results in enumerate(results, start=1)
                ]
                masks_count = len(results)
            else:
                predictions = self._format_predictions(results)
                masks_count = 1

            return {
                "original_text": text,
                "masks_count": masks_count,
                "predictions": predictions,
            }
        except Exception as e:
            return {"error": f"Prediction error: {str(e)}"}

    def predict_batch(self, texts: List[str], top_k: int = 5) -> List[Dict]:
        """
        Predict masked tokens for multiple texts.

        Args:
            texts: List of texts with mask tokens.
            top_k: Number of top predictions to return per mask.

        Returns:
            List of ``predict`` results, one per input text and in the same
            order. Inputs that fail yield an ``error`` dictionary in place.
        """
        return [self.predict(text, top_k) for text in texts]