# (stripped duplicated file-viewer metadata: 96 lines, 3.2 KiB, Python)
from transformers import pipeline
|
|
from typing import Dict, List, Optional
|
|
from src.config import Config
|
|
|
|
|
|
class FillMaskAnalyzer:
    """Fill-mask analyzer built on a Hugging Face ``fill-mask`` pipeline.

    Validates input, runs the pipeline, and converts its raw output into
    plain dictionaries suitable for JSON serialization.
    """

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the fill-mask pipeline

        Args:
            model_name: Name of the model to use (optional); falls back to
                the project default from ``Config.get_model("fillmask")``.
        """
        self.model_name = model_name or Config.get_model("fillmask")
        print(f"Loading fill-mask model: {self.model_name}")
        self.pipeline = pipeline("fill-mask", model=self.model_name)
        print("Model loaded successfully!")

    @staticmethod
    def _format(raw_predictions: List[Dict]) -> List[Dict]:
        """Normalize raw pipeline entries for one mask position.

        Args:
            raw_predictions: Entries as returned by the pipeline, each
                carrying ``token_str``, ``score`` and ``sequence`` keys.

        Returns:
            List of dicts with ``token``, ``score`` (rounded to 4 decimal
            places) and ``sequence`` keys.
        """
        return [
            {
                "token": pred["token_str"],
                "score": round(float(pred["score"]), 4),
                "sequence": pred["sequence"],
            }
            for pred in raw_predictions
        ]

    def _mask_token(self) -> str:
        """Return the model's mask token (e.g. ``[MASK]`` or ``<mask>``).

        Falls back to ``[MASK]`` when the pipeline exposes no tokenizer,
        preserving the original behavior for BERT-style models.
        """
        tokenizer = getattr(self.pipeline, "tokenizer", None)
        return getattr(tokenizer, "mask_token", None) or "[MASK]"

    def predict(self, text: str, top_k: int = 5) -> Dict:
        """
        Predict masked tokens in text

        Args:
            text: Text containing the model's mask token(s) to predict
            top_k: Number of top predictions to return per mask

        Returns:
            Dictionary with ``original_text``, ``masks_count`` and
            ``predictions``, or a dictionary with an ``error`` key on
            invalid input or pipeline failure.
        """
        if not text.strip():
            return {"error": "Empty text"}

        # Fix: the mask token is model-dependent ([MASK] for BERT,
        # <mask> for RoBERTa); the original hard-coded [MASK] and
        # wrongly rejected valid input for other model families.
        mask_token = self._mask_token()
        if mask_token not in text:
            return {"error": f"Text must contain {mask_token} token"}

        try:
            results = self.pipeline(text, top_k=top_k)

            # The pipeline returns a flat list for a single mask and a
            # list of lists when the input contains several masks; the
            # truthiness guard avoids IndexError on an empty result.
            if isinstance(results, list) and results and isinstance(results[0], list):
                predictions = [
                    {
                        "mask_position": position,
                        "predictions": self._format(mask_results),
                    }
                    for position, mask_results in enumerate(results, start=1)
                ]
                masks_count = len(results)
            else:
                predictions = self._format(results)
                masks_count = 1

            return {
                "original_text": text,
                "masks_count": masks_count,
                "predictions": predictions,
            }
        except Exception as e:
            # Surface pipeline failures as a result dict instead of
            # raising, consistent with the validation errors above.
            return {"error": f"Prediction error: {str(e)}"}

    def predict_batch(self, texts: List[str], top_k: int = 5) -> List[Dict]:
        """
        Predict masked tokens for multiple texts

        Args:
            texts: List of texts containing the model's mask token
            top_k: Number of top predictions to return per mask

        Returns:
            List of per-text prediction results (see ``predict``)
        """
        return [self.predict(text, top_k) for text in texts]
|