Fix --idle-timeout issues with encode/decode and parallel generation

oobabooga 2026-03-25 06:37:45 -07:00
parent d6f1485dd1
commit 368f37335f
3 changed files with 28 additions and 9 deletions

modules/logits.py

@@ -4,7 +4,6 @@ import numpy as np
 from modules import models, shared
 from modules.logging_colors import logger
-from modules.models import load_model
 from modules.text_generation import generate_reply
 from modules.utils import check_model_loaded
@@ -12,8 +11,7 @@ global_scores = None
 def get_next_logits(*args, **kwargs):
-    if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
-        shared.model, shared.tokenizer = load_model(shared.model_name)
+    models.load_model_if_idle_unloaded()
 
     needs_lock = not args[2]  # use_samplers
     if needs_lock:
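
This is one of several call sites (see also generate_reply, encode, and decode below) that previously inlined the reload-on-idle check and now delegate to a shared helper added in modules/models.py. A minimal, self-contained sketch of that pattern; ensure_loaded and fake_load are illustrative stand-ins, not names from the repository:

import time

model = None                # stands in for shared.model
tokenizer = None            # stands in for shared.tokenizer
model_name = "demo"         # stands in for shared.model_name
idle_timeout = 5            # minutes, like shared.args.idle_timeout
last_generation_time = time.time()

def fake_load(name):
    # Placeholder for modules.models.load_model(name)
    return f"<model {name}>", f"<tokenizer {name}>"

def ensure_loaded():
    # Reload if --idle-timeout unloaded the model, then reset the idle clock
    global model, tokenizer, last_generation_time
    if idle_timeout > 0 and model is None and model_name not in (None, "None"):
        model, tokenizer = fake_load(model_name)
        last_generation_time = time.time()

ensure_loaded()
assert model is not None    # an idle-unloaded model is transparently restored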

modules/models.py

@@ -1,4 +1,5 @@
 import sys
+import threading
 import time
 
 import modules.shared as shared
@@ -7,6 +8,15 @@ from modules.models_settings import get_model_metadata
 from modules.utils import resolve_model_path
 
 last_generation_time = time.time()
+active_generation_count = 0
+_generation_count_lock = threading.Lock()
+
+
+def load_model_if_idle_unloaded():
+    global last_generation_time
+    if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
+        shared.model, shared.tokenizer = load_model(shared.model_name)
+        last_generation_time = time.time()
 
 
 def load_model(model_name, loader=None):
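
The two new globals amount to a reference count: each request increments active_generation_count while it streams (see modules/text_generation.py below), and the idle watcher skips unloading while the count is nonzero. A runnable sketch of why the Lock is needed around those updates; the counter names mirror the diff, but the worker harness is illustrative:

import threading

active_generation_count = 0
_generation_count_lock = threading.Lock()

def worker():
    global active_generation_count
    for _ in range(10000):
        with _generation_count_lock:   # a bare += across threads can race
            active_generation_count += 1
        with _generation_count_lock:
            active_generation_count -= 1

threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert active_generation_count == 0    # the lock keeps the count exact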
@@ -158,7 +168,10 @@ def unload_model_if_idle():
     while True:
         shared.generation_lock.acquire()
         try:
-            if time.time() - last_generation_time > shared.args.idle_timeout * 60:
+            with _generation_count_lock:
+                is_active = active_generation_count > 0
+
+            if not is_active and time.time() - last_generation_time > shared.args.idle_timeout * 60:
                 if shared.model is not None:
                     logger.info("Unloading the model for inactivity.")
                     unload_model(keep_model_name=True)
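
The watcher now snapshots the counter under the lock before the timestamp comparison, so a parallel generation, which does not hold generation_lock, can no longer have its model unloaded mid-stream. A standalone sketch of this reaper-side logic, with hypothetical callables and poll interval:

import time

def idle_watcher(get_active_count, get_last_time, unload, timeout_minutes, poll_seconds=60):
    # Hypothetical reduction of unload_model_if_idle(): never unload while
    # any generation is in flight, even if the timestamp already looks stale.
    while True:
        time.sleep(poll_seconds)
        if get_active_count() > 0:
            continue               # a parallel request is still streaming
        if time.time() - get_last_time() > timeout_minutes * 60:
            unload()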

modules/text_generation.py

@@ -17,9 +17,7 @@ from modules.utils import check_model_loaded
 def generate_reply(*args, **kwargs):
-    if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
-        from modules.models import load_model
-        shared.model, shared.tokenizer = load_model(shared.model_name)
+    models.load_model_if_idle_unloaded()
 
     state = args[1] if len(args) > 1 else kwargs.get('state', {})
     use_parallel = (
@@ -31,10 +29,16 @@ def generate_reply(*args, **kwargs):
     if not use_parallel:
         shared.generation_lock.acquire()
 
+    with models._generation_count_lock:
+        models.active_generation_count += 1
+
     try:
         for result in _generate_reply(*args, **kwargs):
             yield result
     finally:
+        with models._generation_count_lock:
+            models.active_generation_count -= 1
+
         models.last_generation_time = time.time()
         if not use_parallel:
             shared.generation_lock.release()
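
Putting the decrement in finally matters because generate_reply is a generator: finally also runs when the consumer abandons the stream (GeneratorExit), so a disconnected client cannot leak the count and block idle unloading forever. A minimal demonstration, with illustrative names:

import threading

active = 0               # stands in for models.active_generation_count
lock = threading.Lock()  # stands in for models._generation_count_lock

def stream():
    global active
    with lock:
        active += 1
    try:
        for token in ["a", "b", "c"]:
            yield token
    finally:             # also runs on GeneratorExit
        with lock:
            active -= 1

g = stream()
print(next(g))   # "a"; the counter is 1 while the stream is live
g.close()        # client disconnects early; finally still decrements
assert active == 0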
@@ -126,7 +130,9 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
 def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
-    if shared.tokenizer is None:
-        raise ValueError('No tokenizer is loaded')
+    models.load_model_if_idle_unloaded()
+
+    if shared.tokenizer is None:
+        raise ValueError('No tokenizer is loaded')
 
     # llama.cpp case
     if shared.model.__class__.__name__ == 'LlamaServer':
@@ -176,7 +182,9 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
 def decode(output_ids, skip_special_tokens=True):
-    if shared.tokenizer is None:
-        raise ValueError('No tokenizer is loaded')
+    models.load_model_if_idle_unloaded()
+
+    if shared.tokenizer is None:
+        raise ValueError('No tokenizer is loaded')
 
     return shared.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)
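
In both encode() and decode(), the reload attempt now runs before the tokenizer check, so the ValueError is reserved for the case where no model is configured at all instead of firing whenever --idle-timeout has temporarily unloaded one. A sketch of the reordered control flow, with hypothetical parameters:

def encode_sketch(prompt, ensure_loaded, get_tokenizer):
    # Hypothetical reduction of the fixed control flow in encode()/decode()
    ensure_loaded()            # first, restore an idle-unloaded model
    tokenizer = get_tokenizer()
    if tokenizer is None:      # still None: nothing is configured at all
        raise ValueError('No tokenizer is loaded')
    return tokenizer.encode(prompt)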