mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-01-07 17:20:19 +01:00
llama.cpp: Add a prompt processing progress bar
This commit is contained in:
parent
877cf44c08
commit
faababc4ea
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import os
|
||||
import pprint
|
||||
import re
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
|
|
@ -10,6 +11,7 @@ from pathlib import Path
|
|||
|
||||
import llama_cpp_binaries
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
|
@ -335,17 +337,7 @@ class LlamaServer:
|
|||
env=env
|
||||
)
|
||||
|
||||
def filter_stderr(process_stderr):
    """Forward the subprocess's stderr to ours, dropping noisy server lines.

    Lines beginning with ``'srv '`` or ``'slot '`` and lines containing the
    ``GET /health`` request log are skipped; every other line is written to
    ``sys.stderr`` and flushed immediately. Reading stops when ``readline()``
    returns an empty string.
    """
    noisy_prefixes = ('srv ', 'slot ')
    health_marker = 'log_server_r: request: GET /health'

    try:
        while True:
            text = process_stderr.readline()
            if text == '':
                # Empty string from readline() is the end-of-stream sentinel.
                break

            if text.startswith(noisy_prefixes) or health_marker in text:
                continue

            sys.stderr.write(text)
            sys.stderr.flush()
    except (ValueError, IOError):
        # The pipe was closed while we were reading; nothing left to forward.
        pass
|
||||
|
||||
threading.Thread(target=filter_stderr, args=(self.process.stderr,), daemon=True).start()
|
||||
threading.Thread(target=filter_stderr_with_progress, args=(self.process.stderr,), daemon=True).start()
|
||||
|
||||
# Wait for server to be healthy
|
||||
health_url = f"http://127.0.0.1:{self.port}/health"
|
||||
|
|
@ -396,3 +388,34 @@ class LlamaServer:
|
|||
self.process.kill()
|
||||
|
||||
self.process = None
|
||||
|
||||
|
||||
def filter_stderr_with_progress(process_stderr):
    """Relay a subprocess's stderr while rendering prompt-processing progress.

    Reads ``process_stderr`` line by line until ``readline()`` returns ``''``.
    Lines matching ``slot update_slots: id ... progress = <float>`` drive a
    tqdm bar (total 1.0) that is created on the first progress message and
    closed once the reported value reaches 1.0. Lines that start with
    ``'srv '`` or ``'slot '``, or that contain the ``GET /health`` request
    log, are dropped; every other line is written to ``sys.stderr`` and
    flushed.

    Args:
        process_stderr: Object whose ``readline()`` returns ``str`` lines
            (for example a text-mode pipe).

    ``ValueError`` / ``OSError`` raised while reading or writing (typically a
    closed pipe) end the loop quietly.
    """
    progress_bar = None
    progress_pattern = re.compile(r'slot update_slots: id.*progress = (\d+\.\d+)')

    try:
        for line in iter(process_stderr.readline, ''):
            progress_match = progress_pattern.search(line)

            if progress_match:
                progress = float(progress_match.group(1))

                # Create progress bar on first progress message
                if progress_bar is None:
                    progress_bar = tqdm(total=1.0, desc="Prompt Processing", leave=False)

                # tqdm.update() takes an increment, so convert the absolute
                # fraction reported by the server into a delta.
                progress_bar.update(progress - progress_bar.n)

                # Clean up when complete
                if progress >= 1.0:
                    progress_bar.close()
                    progress_bar = None

            if not line.startswith(('srv ', 'slot ')) and 'log_server_r: request: GET /health' not in line:
                sys.stderr.write(line)
                sys.stderr.flush()

    except (ValueError, OSError):
        # Handle pipe closed exceptions (IOError is an alias of OSError).
        pass
    finally:
        # Close a half-finished bar on every exit path, including a normal
        # end of stream, so a stale bar is not left on the terminal.
        if progress_bar is not None:
            progress_bar.close()
|
||||
|
|
|
|||
Loading…
Reference in a new issue