Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-04-06 07:03:37 +00:00)
Fix --idle-timeout issues with encode/decode and parallel generation
commit 368f37335f
parent d6f1485dd1
3 changed files with 28 additions and 9 deletions
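In short: the duplicated "reload the model if the idle timer unloaded it" check in generate_reply(), get_next_logits(), encode() and decode() is replaced by a single helper, models.load_model_if_idle_unloaded(), and generate_reply() now tracks in-flight generations in models.active_generation_count so the idle watcher never unloads the model while a parallel generation (one that does not hold shared.generation_lock) is still streaming. A minimal, self-contained sketch of the lazy-reload half of the pattern follows; SharedState and fake_load() are placeholders for this sketch, not the project's real objects.

    import time

    class SharedState:                      # placeholder for modules.shared
        model = None
        tokenizer = None
        model_name = "some-model"
        idle_timeout = 10                   # minutes; 0 disables idle unloading

    shared = SharedState()
    last_generation_time = time.time()

    def fake_load(name):                    # placeholder for the real load_model()
        return f"<{name} weights>", f"<{name} tokenizer>"

    def load_model_if_idle_unloaded():
        # Reload only if idle unloading is enabled, the model is gone, and we
        # still know which model to bring back.
        global last_generation_time
        if shared.idle_timeout > 0 and shared.model is None and shared.model_name not in (None, "None"):
            shared.model, shared.tokenizer = fake_load(shared.model_name)
            last_generation_time = time.time()

    def encode(prompt):
        load_model_if_idle_unloaded()       # every entry point calls the helper first
        if shared.tokenizer is None:
            raise ValueError("No tokenizer is loaded")
        return prompt.split()

    print(encode("hello world"))            # reloads transparently, then tokenizes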
modules/logits.py

@@ -4,7 +4,6 @@ import numpy as np
 from modules import models, shared
 from modules.logging_colors import logger
-from modules.models import load_model
 from modules.text_generation import generate_reply
 from modules.utils import check_model_loaded


@@ -12,8 +11,7 @@ global_scores = None


 def get_next_logits(*args, **kwargs):
-    if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
-        shared.model, shared.tokenizer = load_model(shared.model_name)
+    models.load_model_if_idle_unloaded()

     needs_lock = not args[2]  # use_samplers
     if needs_lock:
modules/models.py

@@ -1,4 +1,5 @@
 import sys
+import threading
 import time

 import modules.shared as shared

@@ -7,6 +8,15 @@ from modules.models_settings import get_model_metadata
 from modules.utils import resolve_model_path

 last_generation_time = time.time()
+active_generation_count = 0
+_generation_count_lock = threading.Lock()
+
+
+def load_model_if_idle_unloaded():
+    global last_generation_time
+    if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
+        shared.model, shared.tokenizer = load_model(shared.model_name)
+        last_generation_time = time.time()


 def load_model(model_name, loader=None):

@@ -158,7 +168,10 @@ def unload_model_if_idle():
     while True:
         shared.generation_lock.acquire()
         try:
-            if time.time() - last_generation_time > shared.args.idle_timeout * 60:
+            with _generation_count_lock:
+                is_active = active_generation_count > 0
+
+            if not is_active and time.time() - last_generation_time > shared.args.idle_timeout * 60:
                 if shared.model is not None:
                     logger.info("Unloading the model for inactivity.")
                     unload_model(keep_model_name=True)
modules/text_generation.py

@@ -17,9 +17,7 @@ from modules.utils import check_model_loaded


 def generate_reply(*args, **kwargs):
-    if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
-        from modules.models import load_model
-        shared.model, shared.tokenizer = load_model(shared.model_name)
+    models.load_model_if_idle_unloaded()

     state = args[1] if len(args) > 1 else kwargs.get('state', {})
     use_parallel = (

@@ -31,10 +29,16 @@ def generate_reply(*args, **kwargs):
     if not use_parallel:
         shared.generation_lock.acquire()

+    with models._generation_count_lock:
+        models.active_generation_count += 1
+
     try:
         for result in _generate_reply(*args, **kwargs):
             yield result
     finally:
+        with models._generation_count_lock:
+            models.active_generation_count -= 1
+
         models.last_generation_time = time.time()
         if not use_parallel:
             shared.generation_lock.release()

@@ -126,7 +130,9 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap

 def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
-    if shared.tokenizer is None:
-        raise ValueError('No tokenizer is loaded')
+    models.load_model_if_idle_unloaded()
+
+    if shared.tokenizer is None:
+        raise ValueError('No tokenizer is loaded')

     # llama.cpp case
     if shared.model.__class__.__name__ == 'LlamaServer':

@@ -176,7 +182,9 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):

 def decode(output_ids, skip_special_tokens=True):
-    if shared.tokenizer is None:
-        raise ValueError('No tokenizer is loaded')
+    models.load_model_if_idle_unloaded()
+
+    if shared.tokenizer is None:
+        raise ValueError('No tokenizer is loaded')

     return shared.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)
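The other half of the fix is the generation counter: parallel generations skip shared.generation_lock, so the idle-unload thread can no longer rely on that lock to know whether anything is running. Below is a rough, self-contained sketch of that guard with a tiny timeout and dummy work; apart from the counter idiom itself, the names are placeholders rather than the repository's real code.

    import threading
    import time

    active_generation_count = 0
    _generation_count_lock = threading.Lock()
    last_generation_time = time.time()
    IDLE_TIMEOUT_SECONDS = 1                # tiny value just for this demo
    model = "<loaded model>"

    def generate():
        global active_generation_count, last_generation_time
        with _generation_count_lock:
            active_generation_count += 1    # announce an in-flight generation
        try:
            time.sleep(2)                   # pretend to stream tokens
        finally:
            with _generation_count_lock:
                active_generation_count -= 1
            last_generation_time = time.time()

    def unload_if_idle_once():
        global model
        with _generation_count_lock:
            is_active = active_generation_count > 0
        # Unload only when nothing is running AND the model has sat idle long enough.
        if not is_active and time.time() - last_generation_time > IDLE_TIMEOUT_SECONDS:
            model = None

    t = threading.Thread(target=generate)
    t.start()
    time.sleep(1.5)
    unload_if_idle_once()                   # generation still running: model survives
    print("model while generating:", model)
    t.join()

Checking the counter under the lock before the timestamp comparison mirrors what unload_model_if_idle() does in the diff above, so an in-flight request always wins over the idle timer.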