From faababc4ea5e4548bebc13b50509587343b4c2db Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 25 Apr 2025 16:42:30 -0700 Subject: [PATCH] llama.cpp: Add a prompt processing progress bar --- modules/llama_cpp_server.py | 45 ++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index 7e5e3a4b..85743705 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -1,6 +1,7 @@ import json import os import pprint +import re import socket import subprocess import sys @@ -10,6 +11,7 @@ from pathlib import Path import llama_cpp_binaries import requests +from tqdm import tqdm from modules import shared from modules.logging_colors import logger @@ -335,17 +337,7 @@ class LlamaServer: env=env ) - def filter_stderr(process_stderr): - try: - for line in iter(process_stderr.readline, ''): - if not line.startswith(('srv ', 'slot ')) and 'log_server_r: request: GET /health' not in line: - sys.stderr.write(line) - sys.stderr.flush() - except (ValueError, IOError): - # Handle pipe closed exceptions - pass - - threading.Thread(target=filter_stderr, args=(self.process.stderr,), daemon=True).start() + threading.Thread(target=filter_stderr_with_progress, args=(self.process.stderr,), daemon=True).start() # Wait for server to be healthy health_url = f"http://127.0.0.1:{self.port}/health" @@ -396,3 +388,34 @@ class LlamaServer: self.process.kill() self.process = None + + +def filter_stderr_with_progress(process_stderr): + progress_bar = None + progress_pattern = re.compile(r'slot update_slots: id.*progress = (\d+\.\d+)') + + try: + for line in iter(process_stderr.readline, ''): + progress_match = progress_pattern.search(line) + + if progress_match: + progress = float(progress_match.group(1)) + + # Create progress bar on first progress message + if progress_bar is None: + progress_bar = tqdm(total=1.0, desc="Prompt Processing", leave=False) + + progress_bar.update(progress - progress_bar.n) + + # Clean up when complete + if progress >= 1.0: + progress_bar.close() + progress_bar = None + + if not line.startswith(('srv ', 'slot ')) and 'log_server_r: request: GET /health' not in line: + sys.stderr.write(line) + sys.stderr.flush() + except (ValueError, IOError): + if progress_bar: + progress_bar.close() + pass