diff --git a/modules/text_generation.py b/modules/text_generation.py
index cbe5b61b..d62441df 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -263,6 +263,9 @@ def apply_stopping_strings(reply, all_stop_strings):


 def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
     reply = decode(output_ids[starting_from:], state['skip_special_tokens'] if state else True)

     # Handle tokenizers that do not add the leading space for the first token