diff --git a/modules/text_generation.py b/modules/text_generation.py index 834101ff..1f6a2819 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -73,6 +73,8 @@ def formatted_outputs(reply, model_name): return reply def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): + torch.cuda.empty_cache() + original_question = question if not (shared.args.chat or shared.args.cai_chat): question = apply_extensions(question, "input")