diff --git a/modules/api/completions.py b/modules/api/completions.py
index 98bcff47..f2282731 100644
--- a/modules/api/completions.py
+++ b/modules/api/completions.py
@@ -89,6 +89,7 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
         return [{"token": first_token_str, "null_logprob": True}]
 
     import torch
+    from modules.torch_utils import clear_torch_cache
 
     if hasattr(model, 'get_prompt_logits'):
         logits = model.get_prompt_logits(input_ids)
@@ -143,7 +144,7 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
 
         del chunk_logits, chunk_lse, chunk_top_values
     del logits
-    torch.cuda.empty_cache()
+    clear_torch_cache()
 
     all_top_log_probs = torch.cat(all_top_log_probs_list, dim=0)
     all_top_indices = torch.cat(all_top_indices_list, dim=0)