diff --git a/src/pipelines/textgen.py b/src/pipelines/textgen.py index e29d7c8..31b441a 100644 --- a/src/pipelines/textgen.py +++ b/src/pipelines/textgen.py @@ -30,7 +30,7 @@ class TextGenerator: print("Model loaded successfully!") - def generate(self, prompt: str, max_new_tokens: int = 100, num_return_sequences: int = 1, + def generate(self, prompt: str, max_new_tokens: int = 500, num_return_sequences: int = 1, temperature: float = 1.0, do_sample: bool = True) -> Dict: """ Generate text from a prompt