diff --git a/.github/workflows/build-portable-release-rocm.yml b/.github/workflows/build-portable-release-rocm.yml index 6f9ea4ec..1050fa7e 100644 --- a/.github/workflows/build-portable-release-rocm.yml +++ b/.github/workflows/build-portable-release-rocm.yml @@ -148,11 +148,11 @@ jobs: # 6. Create archive cd .. if [[ "$RUNNER_OS" == "Windows" ]]; then - ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm.zip" + ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.zip" echo "Creating archive: $ARCHIVE_NAME" powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME" else - ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm.tar.gz" + ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.tar.gz" echo "Creating archive: $ARCHIVE_NAME" tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}" fi diff --git a/README.md b/README.md index a47a0e6f..989659d1 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ # Text Generation Web UI -A Gradio web UI for running Large Language Models locally. 100% private, offline, and free. +A Gradio web UI for running Large Language Models locally. 100% private and offline. Supports text generation, vision, tool-calling, training, image generation, and more. [Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason) @@ -23,22 +23,21 @@ A Gradio web UI for running Large Language Models locally. 100% private, offline ## Features -- Supports multiple local text generation backends, including [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) (the latter via its own [Dockerfile](https://github.com/oobabooga/text-generation-webui/blob/main/docker/TensorRT-LLM/Dockerfile)). 
-- Easy setup: Choose between **portable builds** (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or the one-click installer that creates a self-contained `installer_files` directory. +- **Multiple backends**: [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Switch between backends and models without restarting. +- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents. +- **Vision (multimodal)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)). +- **Tool-calling**: Models can call custom functions during chat — web search, page fetching, math, and more. Each tool is a single `.py` file, easy to create and extend ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Tool-Calling-Tutorial)). +- **OpenAI-compatible API**: Chat and Completions endpoints with tool-calling support. Use as a local drop-in replacement for the OpenAI API ([examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples)). +- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)). +- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)). +- **Easy setup**: [Portable builds](https://github.com/oobabooga/text-generation-webui/releases) (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or a one-click installer for the full feature set. 
- 100% offline and private, with zero telemetry, external resources, or remote update requests. - `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters. Prompts are automatically formatted with Jinja2 templates. -- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents. -- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)). -- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)). -- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation. -- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Easy to use, good defaults, and supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)). - Edit messages, navigate between message versions, and branch conversations at any point. -- Switch between different models in the UI without restarting. - Free-form text generation in the Notebook tab without being limited to chat turns. - Multiple sampling parameters and generation options for sophisticated text generation control. - Aesthetic UI with dark and light themes. - Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions. -- OpenAI-compatible API with Chat and Completions endpoints, including tool-calling support – see [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples). - Extension support, with numerous built-in and user-contributed extensions available. 
See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details. ## How to install @@ -47,10 +46,11 @@ A Gradio web UI for running Large Language Models locally. 100% private, offline No installation needed – just download, unzip and run. All dependencies included. -Compatible with GGUF (llama.cpp) models on Windows, Linux, and macOS. [Check what models fit your hardware](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator). - Download from here: **https://github.com/oobabooga/text-generation-webui/releases** +- Builds are provided for Linux, Windows, and macOS, with options for CUDA, Vulkan, ROCm, and CPU-only. +- Compatible with GGUF (llama.cpp) models. + #### Option 2: Manual portable install with venv Very fast setup that should work on any Python 3.9+: @@ -81,7 +81,7 @@ deactivate #### Option 3: One-click installer -For users who need additional backends (ExLlamaV3, Transformers) or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch. +For users who need additional backends (ExLlamaV3, Transformers), training, image generation, or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch. 1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it. 2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`. 
@@ -146,7 +146,7 @@ conda activate textgen |--------|---------|---------| | Linux/WSL | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` | | Linux/WSL | CPU only | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu` | -| Linux | AMD | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/rocm6.4` | +| Linux | AMD | `pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp313-cp313-linux_x86_64.whl` | | MacOS + MPS | Any | `pip3 install torch==2.9.1` | | Windows | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` | | Windows | CPU only | `pip3 install torch==2.9.1` | @@ -201,7 +201,7 @@ ln -s docker/{nvidia/Dockerfile,nvidia/docker-compose.yml,.dockerignore} . For AMD GPU: ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} . For Intel GPU: -ln -s docker/{intel/Dockerfile,amd/docker-compose.yml,.dockerignore} . +ln -s docker/{intel/Dockerfile,intel/docker-compose.yml,.dockerignore} . For CPU only ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} . 
cp docker/.env.example .env @@ -236,20 +236,24 @@ List of command-line flags ```txt -usage: server.py [-h] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS] +usage: server.py [-h] [--user-data-dir USER_DATA_DIR] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS] [--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--image-model IMAGE_MODEL] [--image-model-dir IMAGE_MODEL_DIR] [--image-dtype {bfloat16,float16}] [--image-attn-backend {flash_attention_2,sdpa}] [--image-cpu-offload] [--image-compile] [--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}] [--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT] [--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT] [--ctx-size-draft CTX_SIZE_DRAFT] [--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}] [--spec-ngram-size-n SPEC_NGRAM_SIZE_N] [--spec-ngram-size-m SPEC_NGRAM_SIZE_M] [--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS] [--gpu-layers N] [--cpu-moe] [--mmproj MMPROJ] [--streaming-llm] [--tensor-split TENSOR_SPLIT] [--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--ubatch-size UBATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa] - [--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16] [--no-cache] [--trust-remote-code] - [--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE] [--quant_type QUANT_TYPE] - [--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--cpp-runner] - [--alpha_value ALPHA_VALUE] [--rope_freq_base ROPE_FREQ_BASE] [--compress_pos_emb 
COMPRESS_POS_EMB] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share] + [--parallel PARALLEL] [--fit-target FIT_TARGET] [--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16] + [--no-cache] [--trust-remote-code] [--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE] + [--quant_type QUANT_TYPE] [--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share] [--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors] [--portable] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4] - [--nowebui] + [--nowebui] [--temperature N] [--dynatemp-low N] [--dynatemp-high N] [--dynatemp-exponent N] [--smoothing-factor N] [--smoothing-curve N] [--min-p N] [--top-p N] [--top-k N] + [--typical-p N] [--xtc-threshold N] [--xtc-probability N] [--epsilon-cutoff N] [--eta-cutoff N] [--tfs N] [--top-a N] [--top-n-sigma N] [--adaptive-target N] [--adaptive-decay N] + [--dry-multiplier N] [--dry-allowed-length N] [--dry-base N] [--repetition-penalty N] [--frequency-penalty N] [--presence-penalty N] [--encoder-repetition-penalty N] + [--no-repeat-ngram-size N] [--repetition-penalty-range N] [--penalty-alpha N] [--guidance-scale N] [--mirostat-mode N] [--mirostat-tau N] [--mirostat-eta N] + [--do-sample | --no-do-sample] [--dynamic-temperature | --no-dynamic-temperature] [--temperature-last | --no-temperature-last] [--sampler-priority N] [--dry-sequence-breakers N] + [--enable-thinking | --no-enable-thinking] [--reasoning-effort N] [--chat-template-file CHAT_TEMPLATE_FILE] Text Generation 
Web UI @@ -257,7 +261,8 @@ options: -h, --help show this help message and exit Basic settings: - --multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly. + --user-data-dir USER_DATA_DIR Path to the user data directory. Default: auto-detected. + --multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams. --model MODEL Name of the model to load by default. --lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces. --model-dir MODEL_DIR Path to directory with all the models. @@ -280,12 +285,12 @@ Image model: Quantization method for image model. Model loader: - --loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, - TensorRT-LLM. + --loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT- + LLM. Context and cache: - --ctx-size N, --n_ctx N, --max_seq_len N Context size in tokens. llama.cpp: 0 = auto if gpu-layers is also -1. - --cache-type N, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8). + --ctx-size, --n_ctx, --max_seq_len N Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. + --cache-type, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8). Speculative decoding: --model-draft MODEL_DRAFT Path to the draft model for speculative decoding. @@ -300,7 +305,7 @@ Speculative decoding: --spec-ngram-min-hits SPEC_NGRAM_MIN_HITS Minimum n-gram hits for ngram-map speculative decoding. 
llama.cpp: - --gpu-layers N, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto. + --gpu-layers, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto. --cpu-moe Move the experts to the CPU (for MoE models). --mmproj MMPROJ Path to the mmproj file for vision models. --streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed. @@ -314,13 +319,17 @@ llama.cpp: --threads THREADS Number of threads to use. --threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing. --numa Activate NUMA task allocation for llama.cpp. + --parallel PARALLEL Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set + ctx_size to 32768. + --fit-target FIT_TARGET Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices. + Default: 1024. --extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU" Transformers/Accelerate: --cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow. --cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading. --disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. - --disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to. Defaults to "user_data/cache". + --disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to. --load-in-8bit Load the model with 8-bit precision (using bitsandbytes). --bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. --no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost. @@ -341,14 +350,6 @@ ExLlamaV3: --tp-backend TP_BACKEND The backend for tensor parallelism. Valid options: native, nccl. 
Default: native. --cfg-cache Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader. -TensorRT-LLM: - --cpp-runner Use the ModelRunnerCpp runner, which is faster than the default ModelRunner. - -RoPE: - --alpha_value ALPHA_VALUE Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both. - --rope_freq_base ROPE_FREQ_BASE If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63). - --compress_pos_emb COMPRESS_POS_EMB Positional embeddings compression factor. Should be set to (context length) / (model's original context length). Equal to 1/rope_freq_scale. - Gradio: --listen Make the web UI reachable from your local network. --listen-port LISTEN_PORT The listening port that the server will use. @@ -365,7 +366,7 @@ Gradio: API: --api Enable the API extension. - --public-api Create a public URL for the API using Cloudfare. + --public-api Create a public URL for the API using Cloudflare. --public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. --api-port API_PORT The listening port for the API. --api-key API_KEY API authentication key. @@ -373,28 +374,67 @@ API: --api-enable-ipv6 Enable IPv6 for the API --api-disable-ipv4 Disable IPv4 for the API --nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode. 
+ +API generation defaults: + --temperature N Temperature + --dynatemp-low N Dynamic temperature low + --dynatemp-high N Dynamic temperature high + --dynatemp-exponent N Dynamic temperature exponent + --smoothing-factor N Smoothing factor + --smoothing-curve N Smoothing curve + --min-p N Min P + --top-p N Top P + --top-k N Top K + --typical-p N Typical P + --xtc-threshold N XTC threshold + --xtc-probability N XTC probability + --epsilon-cutoff N Epsilon cutoff + --eta-cutoff N Eta cutoff + --tfs N TFS + --top-a N Top A + --top-n-sigma N Top N Sigma + --adaptive-target N Adaptive target + --adaptive-decay N Adaptive decay + --dry-multiplier N DRY multiplier + --dry-allowed-length N DRY allowed length + --dry-base N DRY base + --repetition-penalty N Repetition penalty + --frequency-penalty N Frequency penalty + --presence-penalty N Presence penalty + --encoder-repetition-penalty N Encoder repetition penalty + --no-repeat-ngram-size N No repeat ngram size + --repetition-penalty-range N Repetition penalty range + --penalty-alpha N Penalty alpha + --guidance-scale N Guidance scale + --mirostat-mode N Mirostat mode + --mirostat-tau N Mirostat tau + --mirostat-eta N Mirostat eta + --do-sample, --no-do-sample Do sample + --dynamic-temperature, --no-dynamic-temperature Dynamic temperature + --temperature-last, --no-temperature-last Temperature last + --sampler-priority N Sampler priority + --dry-sequence-breakers N DRY sequence breakers + --enable-thinking, --no-enable-thinking Enable thinking + --reasoning-effort N Reasoning effort + --chat-template-file CHAT_TEMPLATE_FILE Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model's + built-in template. ``` ## Downloading models -Models should be placed in the folder `text-generation-webui/user_data/models`. They are usually downloaded from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf). +1. 
Download a GGUF model file from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf). +2. Place it in the `user_data/models` folder. -To check if a GGUF model will fit in your hardware before downloading it, you can use this tool I created: +That's it. The UI will detect it automatically. -[Accurate GGUF VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator) +To check what will fit your GPU, you can use the [VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator). -* GGUF models are a single file and should be placed directly into `user_data/models`. Example: +
+Other model types (Transformers, EXL3) -``` -text-generation-webui -└── user_data - └── models - └── llama-2-13b-chat.Q4_K_M.gguf -``` - -* The remaining model types (like 16-bit Transformers models and EXL3 models) are made of several files and must be placed in a subfolder. Example: +Models that consist of multiple files (like 16-bit Transformers models and EXL3 models) should be placed in a subfolder inside `user_data/models`: ``` text-generation-webui @@ -404,31 +444,18 @@ text-generation-webui ├── config.json ├── generation_config.json ├── model-00001-of-00004.safetensors - ├── model-00002-of-00004.safetensors - ├── model-00003-of-00004.safetensors - ├── model-00004-of-00004.safetensors - ├── model.safetensors.index.json - ├── special_tokens_map.json + ├── ... ├── tokenizer_config.json └── tokenizer.json ``` -In both cases, you can use the "Model" tab of the UI to download the model from Hugging Face automatically. It is also possible to download it via the command-line with: - -``` -python download-model.py organization/model -``` - -Run `python download-model.py --help` to see all the options. +These formats require the one-click installer (not the portable build). +
## Documentation https://github.com/oobabooga/text-generation-webui/wiki -## Google Colab notebook - -https://colab.research.google.com/github/oobabooga/text-generation-webui/blob/main/Colab-TextGen-GPU.ipynb - ## Community https://www.reddit.com/r/Oobabooga/ diff --git a/cmd_windows.bat b/cmd_windows.bat index 787b4335..b0540bd8 100755 --- a/cmd_windows.bat +++ b/cmd_windows.bat @@ -21,6 +21,7 @@ set INSTALL_ENV_DIR=%cd%\installer_files\env set PYTHONNOUSERSITE=1 set PYTHONPATH= set PYTHONHOME= +set PYTHONUTF8=1 set "CUDA_PATH=%INSTALL_ENV_DIR%" set "CUDA_HOME=%CUDA_PATH%" diff --git a/css/chat_style-Dark.css b/css/chat_style-Dark.css index 6a4784cc..01a168ab 100644 --- a/css/chat_style-Dark.css +++ b/css/chat_style-Dark.css @@ -2,6 +2,7 @@ display: grid; align-items: start; grid-template-columns: 60px minmax(0, 1fr); + width: min(100%, calc(724px + 60px)); padding-bottom: 22px; padding-top: 6px; font-size: 18px; @@ -91,9 +92,6 @@ } .message-body p { - margin-bottom: 0 !important; - font-size: 16px !important; - line-height: 1.5 !important; color: #e0e0e0 !important; /* Light color for text */ } @@ -122,7 +120,7 @@ } .message-body p { - font-size: 14px !important; /* Smaller text for mobile */ + font-size: 14px !important; } .username { diff --git a/css/chat_style-TheEncrypted777.css b/css/chat_style-TheEncrypted777.css index fbd47072..9543a3df 100644 --- a/css/chat_style-TheEncrypted777.css +++ b/css/chat_style-TheEncrypted777.css @@ -4,6 +4,7 @@ display: grid; align-items: start; grid-template-columns: 60px minmax(0, 1fr); + width: min(100%, calc(724px + 60px + 90px)); padding-bottom: 21px; padding-top: 7px; font-size: 18px; @@ -86,10 +87,8 @@ border-radius: 20px; } -.message-body p { - margin-bottom: 0 !important; +.message-body p, .message-body li { font-size: 18px !important; - line-height: 1.428571429 !important; color: rgb(243 244 246) !important; text-shadow: 2px 2px 2px rgb(0 0 0); font-weight: 500; @@ -127,7 +126,7 @@ padding-left: 0; } - .message-body p 
{ + .message-body p, .message-body li { font-size: 16px !important; } diff --git a/css/chat_style-cai-chat-square.css b/css/chat_style-cai-chat-square.css index 291a1209..8254a4ec 100644 --- a/css/chat_style-cai-chat-square.css +++ b/css/chat_style-cai-chat-square.css @@ -19,4 +19,5 @@ padding-bottom: 1.5em; padding-top: 0.5em; grid-template-columns: 70px minmax(0, 1fr); + width: min(100%, calc(724px + 70px)); } diff --git a/css/chat_style-cai-chat.css b/css/chat_style-cai-chat.css index b06b1269..66d2816d 100644 --- a/css/chat_style-cai-chat.css +++ b/css/chat_style-cai-chat.css @@ -2,6 +2,7 @@ display: grid; align-items: start; grid-template-columns: 60px minmax(0, 1fr); + width: min(100%, calc(724px + 60px)); padding-bottom: 1.5em; padding-top: 0.5em; font-size: 15px; @@ -46,16 +47,10 @@ border-radius: 20px; } -.message-body p { - font-size: 15px !important; - line-height: 22.5px !important; +.message-body p, .message-body li { font-weight: 500; } -.message-body p, .chat .message-body ul, .chat .message-body ol { - margin-bottom: 10px !important; -} - .dark .message-body p em { color: rgb(138 138 138) !important; } diff --git a/css/chat_style-messenger.css b/css/chat_style-messenger.css index 70fd6d4a..fd9b5b70 100644 --- a/css/chat_style-messenger.css +++ b/css/chat_style-messenger.css @@ -1,4 +1,5 @@ .message { + width: min(100%, calc(724px + 60px)); padding-bottom: 22px; padding-top: 3px; font-size: 15px; @@ -60,8 +61,10 @@ text-align: right; } -.dark .circle-bot + .text div, .dark .circle-bot + .text * { - color: #000; +.dark .circle-bot + .text div, .dark .circle-bot + .text *, +.dark .chat .message .circle-bot + .text .message-body :is(h1, h2, h3, h4, h5, h6), +.dark .chat .message .circle-bot + .text .message-body a { + color: #000 !important; } .text { @@ -76,19 +79,14 @@ font-weight: bold; } -.message-body { -} - .message-body img { max-width: 300px; max-height: 300px; border-radius: 20px; } -.message-body p { - margin-bottom: 0 !important; 
+.message-body p, .message-body li { font-size: 15px !important; - line-height: 1.428571429 !important; font-weight: 500; } diff --git a/css/chat_style-wpp.css b/css/chat_style-wpp.css index b2ac4d2a..65e253d9 100644 --- a/css/chat_style-wpp.css +++ b/css/chat_style-wpp.css @@ -1,5 +1,6 @@ .message { display: block; + width: min(100%, 724px); padding-top: 0; padding-bottom: 21px; font-size: 15px; @@ -77,14 +78,8 @@ border-radius: 12px; } -.message-body p { +.message-body p, .message-body li { font-size: 15px !important; - line-height: 1.4 !important; - font-weight: 400; -} - -.message-body p:first-child { - margin-top: 0 !important; } .dark .message-body p em { @@ -100,6 +95,3 @@ margin-top: 8px; } -.message-body p, .chat .message-body ul, .chat .message-body ol { - margin-bottom: 10px !important; -} diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css index 72a148c3..458feafc 100644 --- a/css/html_instruct_style.css +++ b/css/html_instruct_style.css @@ -78,7 +78,7 @@ .chat .user-message .text, .chat .assistant-message .text { - max-width: 700px; + max-width: 724px; margin-left: auto; margin-right: auto; } diff --git a/css/main.css b/css/main.css index 2d19b5c8..25ae15b1 100644 --- a/css/main.css +++ b/css/main.css @@ -400,7 +400,6 @@ audio { } .chat .message { - width: min(100%, 48rem); margin-left: auto; margin-right: auto; text-align: start; @@ -431,10 +430,19 @@ audio { font-size: 16px; } -.dark .message-body :is(h1, h2, h3, h4, h5, h6) { +.dark .message-body h1, +.dark .message-body h2, +.dark .message-body h3, +.dark .message-body h4, +.dark .message-body h5, +.dark .message-body h6 { color: white !important; } +.dark .message-body blockquote { + border-left-color: rgb(255 255 255 / 30%); +} + .message-body h1 { font-weight: 800; font-size: 2.25em; @@ -831,9 +839,20 @@ audio { } } -.message-body ol, .message-body ul { +.message-body p, .message-body li { + line-height: 1.75 !important; +} + +.message-body p, .message-body ul, .message-body 
ol { + margin: 1.25em 0 !important; +} + +.message-body :is(p, ul, ol):first-child { margin-top: 0 !important; - margin-bottom: 1.25em !important; +} + +.message-body :is(p, ul, ol):last-child { + margin-bottom: 0 !important; } /* ---------------------------------------------- @@ -1003,6 +1022,49 @@ audio { padding-right: 0.5rem; } +#new-chat-wrapper { + display: contents; +} + +.new-chat-arrow { + cursor: pointer; + position: relative; + padding: 0; + margin-right: -15px; + height: 39.594px; + display: flex; + align-items: center; +} + +.new-chat-menu { + display: none; + position: absolute; + top: 0; + left: 0; + padding-top: 1.2em; + z-index: var(--layer-top); + white-space: nowrap; +} + +.new-chat-arrow:hover .new-chat-menu { + display: block; +} + +.new-chat-menu-item { + cursor: pointer; + padding: var(--size-2); + background: var(--background-fill-primary); + box-shadow: var(--shadow-drop-lg); + border-radius: var(--container-radius); + color: var(--body-text-color); + font-size: var(--text-md); + font-weight: var(--button-large-text-weight); +} + +.new-chat-menu-item:hover { + background: var(--background-fill-secondary); +} + #past-chats-row, #chat-controls { width: 260px; @@ -1373,7 +1435,6 @@ audio { overflow-wrap: break-word; max-height: 250px; overflow-y: scroll; - contain: layout; } .chat .message-body .thinking-content p, @@ -1662,7 +1723,7 @@ button:focus { .chat-parent { /* Optimize for scrolling performance */ will-change: scroll-position; - contain: layout style paint; + contain: style paint; /* Ensure GPU acceleration */ transform: translateZ(0); @@ -1802,6 +1863,15 @@ table { border-collapse: collapse; } +.table-wrapper { + overflow-x: auto; +} + +.message-body :is(td, th) { + word-break: normal; + overflow-wrap: normal; +} + table, tr, td, th, thead { border: 0; } @@ -1814,3 +1884,86 @@ tr + tr th { border-top: 1px solid; } thead + tbody tr:first-child td, thead + tbody tr:first-child th { border-top: 1px solid; } + +/* 
------------------------------------------------ + Tools CheckboxGroup - vertical DragDrop-like style + ------------------------------------------------ */ + +/* "Refresh list" link in the Tools label */ +.tools-refresh-link { + cursor: pointer; +} + +/* Checkbox list container */ +#tools-group { + padding: 0 !important; + border-width: 0 !important; + background: transparent !important; + min-height: 0 !important; +} + +#tools-group .wrap { + display: flex; + flex-direction: column; + flex-wrap: nowrap; + gap: 4px; + padding: 0; + margin-top: var(--spacing-lg); + max-height: 350px; + overflow-y: auto; +} + +/* Pretty scrollbar for the tools list */ +#tools-group .wrap::-webkit-scrollbar { + width: 8px; + height: 8px; +} + +#tools-group .wrap::-webkit-scrollbar-track { + background: transparent; +} + +#tools-group .wrap::-webkit-scrollbar-thumb, +#tools-group .wrap::-webkit-scrollbar-thumb:hover { + background: var(--neutral-300); + border-radius: 30px; +} + +.dark #tools-group .wrap::-webkit-scrollbar-thumb, +.dark #tools-group .wrap::-webkit-scrollbar-thumb:hover { + background: rgb(255 255 255 / 6.25%); + border-radius: 10px; +} + +#tools-group .wrap::-webkit-scrollbar-corner { + background: transparent; +} + +/* Each checkbox item */ +#tools-group label { + display: flex; + align-items: center; + gap: 8px; + padding: 5px 8px; + border-radius: var(--radius-sm, 4px); + background: var(--block-background-fill); + border: 1px solid var(--border-color-primary); + color: var(--body-text-color); + font-size: var(--input-text-size); + font-weight: var(--input-text-weight); + cursor: pointer; + user-select: none; + transition: border-color 0.15s ease, background 0.15s ease; + box-shadow: none; +} + +#tools-group label:hover { + border-color: var(--input-border-color-focus); +} + +#tools-group label span { + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} diff --git a/docs/04 - Model Tab.md b/docs/04 - Model Tab.md index 
4d5ae645..744970ac 100644 --- a/docs/04 - Model Tab.md +++ b/docs/04 - Model Tab.md @@ -41,9 +41,6 @@ Options: * **cpu_memory**: Maximum CPU memory in GiB to use for CPU offloading via the accelerate library. Whatever doesn't fit in the GPU or CPU will go to a disk cache if the "disk" checkbox is enabled. * **compute_dtype**: Used when "load_in_4bit" is checked. I recommend leaving the default value. * **quant_type**: Used when "load_in_4bit" is checked. I recommend leaving the default value. -* **alpha_value**: Used to extend the context length of a model with a minor loss in quality. I have measured 1.75 to be optimal for 1.5x context, and 2.5 for 2x context. That is, with alpha = 2.5 you can make a model with 4096 context length go to 8192 context length. -* **rope_freq_base**: Originally another way to write "alpha_value", it ended up becoming a necessary parameter for some models like CodeLlama, which was fine-tuned with this set to 1000000 and hence needs to be loaded with it set to 1000000 as well. -* **compress_pos_emb**: The first and original context-length extension method, discovered by [kaiokendev](https://kaiokendev.github.io/til). When set to 2, the context length is doubled, 3 and it's tripled, etc. It should only be used for models that have been fine-tuned with this parameter set to different than 1. For models that have not been tuned to have greater context length, alpha_value will lead to a smaller accuracy loss. * **attn_implementation**: Choose the attention implementation. Valid options: `sdpa`, `eager`, `flash_attention_2`. The default (`sdpa`) works well in most cases; `flash_attention_2` may be useful for training. * **cpu**: Loads the model in CPU mode using Pytorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with transformers is older than llama.cpp and it works, but it's a lot slower. Note: this parameter has a different interpretation in the llama.cpp loader (see above). 
* **load_in_8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training and not inference, so load_in_8bit is slower than load_in_4bit (but more accurate). diff --git a/docs/05 - Training Tab.md b/docs/05 - Training Tab.md index 902693e6..0bfc59aa 100644 --- a/docs/05 - Training Tab.md +++ b/docs/05 - Training Tab.md @@ -4,7 +4,7 @@ A LoRA is tied to a specific model architecture — a LoRA trained on Llama 3 8B ### Quick Start -1. Load your base model (no LoRAs loaded). +1. Load your base model with the **Transformers** loader (no LoRAs loaded). 2. Open the **Training** tab > **Train LoRA**. 3. Pick a dataset and configure parameters (see [below](#parameters)). 4. Click **Start LoRA Training** and monitor the [loss](#loss). diff --git a/docs/Tool Calling Tutorial.md b/docs/Tool Calling Tutorial.md new file mode 100644 index 00000000..801e9d78 --- /dev/null +++ b/docs/Tool Calling Tutorial.md @@ -0,0 +1,159 @@ +## Supported models + +The following models are supported: + +- Qwen 3.5 +- GPT-OSS +- Mistral Small / Devstral +- DeepSeek V3 +- Kimi-K2 +- MiniMax-M2.5 +- GLM-5 +- Llama 4 + +Other models that output tool calls as JSON (inside XML tags, code blocks, or plain JSON) are also supported through a generic fallback parser. + +## Tool calling in the UI + +### 1. Load a model with tool-calling support + +Load a model with tool-calling support from the Model tab. + +### 2. Select tools + +In the chat sidebar, check the tools you want the model to use: + +- **web_search** -- Search the web using DuckDuckGo. +- **fetch_webpage** -- Fetch the content of a URL. +- **calculate** -- Evaluate math expressions. +- **get_datetime** -- Get the current date and time. +- **roll_dice** -- Roll dice. + +### 3. Chat + +Send a message as usual. When the model decides it needs a tool, it will call it automatically. You will see each tool call and its result in a collapsible accordion inside the chat message. 
+ +The model may call multiple tools in sequence before giving its final answer. + +## Writing custom tools + +Each tool is a single `.py` file in `user_data/tools/`. It needs two things: + +1. A `tool` dictionary that describes the function (name, description, parameters). +2. An `execute(arguments)` function that runs it and returns the result. + +Here is a minimal example (`user_data/tools/get_datetime.py`): + +```python +from datetime import datetime + +tool = { + "type": "function", + "function": { + "name": "get_datetime", + "description": "Get the current date and time.", + "parameters": { + "type": "object", + "properties": {}, + } + } +} + + +def execute(arguments): + now = datetime.now() + return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")} +``` + +An example with parameters (`user_data/tools/roll_dice.py`): + +```python +import random + +tool = { + "type": "function", + "function": { + "name": "roll_dice", + "description": "Roll one or more dice with the specified number of sides.", + "parameters": { + "type": "object", + "properties": { + "count": {"type": "integer", "description": "Number of dice to roll.", "default": 1}, + "sides": {"type": "integer", "description": "Number of sides per die.", "default": 20}, + }, + } + } +} + + +def execute(arguments): + count = max(1, min(arguments.get("count", 1), 1000)) + sides = max(2, min(arguments.get("sides", 20), 1000)) + rolls = [random.randint(1, sides) for _ in range(count)] + return {"rolls": rolls, "total": sum(rolls)} +``` + +You can open the built-in tools in `user_data/tools/` for more examples. + +## Tool calling over the API + +Tool calling over the API follows the [OpenAI API](https://platform.openai.com/docs/guides/function-calling) convention. Define your tools, send them with your messages, and handle tool calls in a loop until the model gives a final answer. 
+ +```python +import json +import requests + +url = "http://127.0.0.1:5000/v1/chat/completions" + +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a given location.", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + }, + "required": ["location"] + } + } + } +] + + +def execute_tool(name, arguments): + if name == "get_weather": + return {"temperature": "14°C", "condition": "partly cloudy"} + return {"error": f"Unknown tool: {name}"} + + +messages = [{"role": "user", "content": "What's the weather like in Paris?"}] + +for _ in range(10): + response = requests.post(url, json={"messages": messages, "tools": tools}).json() + choice = response["choices"][0] + + if choice["finish_reason"] == "tool_calls": + messages.append({ + "role": "assistant", + "content": choice["message"]["content"], + "tool_calls": choice["message"]["tool_calls"], + }) + + for tool_call in choice["message"]["tool_calls"]: + name = tool_call["function"]["name"] + arguments = json.loads(tool_call["function"]["arguments"]) + result = execute_tool(name, arguments) + print(f"Tool call: {name}({arguments}) => {result}") + + messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": json.dumps(result), + }) + else: + print(f"\nAssistant: {choice['message']['content']}") + break +``` diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 8ba031c1..fc17a19a 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -11,8 +11,10 @@ from pydantic import ValidationError from extensions.openai.errors import InvalidRequestError from extensions.openai.typing import ToolDefinition -from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall +from extensions.openai.utils import debug_msg +from modules.tool_parsing import get_tool_call_id, parse_tool_call, 
detect_tool_call_format from modules import shared +from modules.reasoning import extract_reasoning from modules.chat import ( generate_chat_prompt, generate_chat_reply, @@ -37,17 +39,114 @@ def load_chat_template_file(filepath): return text -def convert_logprobs_to_tiktoken(model, logprobs): - # more problems than it's worth. - # try: - # encoder = tiktoken.encoding_for_model(model) - # # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall. - # return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()]) - # except KeyError: - # # assume native tokens if we can't find the tokenizer - # return logprobs +def _get_raw_logprob_entries(offset=0): + """Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset. - return logprobs + Returns (new_entries, new_offset). + """ + if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities: + return [], offset + + all_entries = shared.model.last_completion_probabilities + new_entries = all_entries[offset:] + return new_entries, len(all_entries) + + +def _dict_to_logprob_entries(token_dict): + """Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format.""" + if not token_dict: + return [] + + return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}] + + +def _parse_entry_top(entry): + """Extract the top logprobs list from a raw entry, handling both key names.""" + return entry.get('top_logprobs', entry.get('top_probs', [])) + + +def format_chat_logprobs(entries): + """Format logprob entries into OpenAI chat completions logprobs format. 
+ + Output: {"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}]} + """ + if not entries: + return None + + content = [] + for entry in entries: + top = _parse_entry_top(entry) + if not top: + continue + + chosen = top[0] + token_str = chosen.get('token', '') + token_logprob = chosen.get('logprob', chosen.get('prob', 0)) + + top_list = [] + for item in top: + t = item.get('token', '') + lp = item.get('logprob', item.get('prob', 0)) + top_list.append({ + "token": t, + "logprob": lp, + "bytes": list(t.encode('utf-8')) if t else None + }) + + content.append({ + "token": token_str, + "logprob": token_logprob, + "bytes": list(token_str.encode('utf-8')) if token_str else None, + "top_logprobs": top_list + }) + + return {"content": content, "refusal": None} if content else None + + +def format_completion_logprobs(entries): + """Format logprob entries into OpenAI completions logprobs format. + + Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"} + """ + if not entries: + return None + + tokens = [] + token_logprobs = [] + top_logprobs = [] + text_offset = [] + offset = 0 + + for entry in entries: + top = _parse_entry_top(entry) + if not top: + continue + + chosen = top[0] + token_str = chosen.get('token', '') + token_logprob = chosen.get('logprob', chosen.get('prob', 0)) + + tokens.append(token_str) + token_logprobs.append(token_logprob) + text_offset.append(offset) + offset += len(token_str) + + top_dict = {} + for item in top: + t = item.get('token', '') + lp = item.get('logprob', item.get('prob', 0)) + top_dict[t] = lp + top_logprobs.append(top_dict) + + if not tokens: + return None + + return { + "tokens": tokens, + "token_logprobs": token_logprobs, + "top_logprobs": top_logprobs, + "text_offset": text_offset + } def process_parameters(body, is_legacy=False): @@ -72,7 +171,16 @@ def process_parameters(body, is_legacy=False): elif isinstance(body['stop'], list): generate_params['custom_stopping_strings'] = body['stop'] - 
if shared.args.loader != 'llama.cpp': + # Resolve logprobs: for chat completions, logprobs is a bool and the count + # comes from top_logprobs. Normalize to an int for all backends. + logprobs = body.get('logprobs', None) + top_logprobs = body.get('top_logprobs', None) + if logprobs is True: + logprobs = max(top_logprobs, 1) if top_logprobs is not None else 5 + generate_params['logprobs'] = logprobs + + # For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively + if shared.args.loader not in ('llama.cpp', 'ExLlamav3'): from transformers import LogitsProcessorList from modules.transformers_loader import ( @@ -85,13 +193,9 @@ def process_parameters(body, is_legacy=False): if logit_bias: # {str: float, ...} logits_processor = [LogitsBiasProcessor(logit_bias)] - logprobs = None # coming to chat eventually - if 'logprobs' in body: - logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5. + if logprobs is not None and logprobs > 0: generate_params['logprob_proc'] = LogprobProcessor(logprobs) logits_processor.extend([generate_params['logprob_proc']]) - else: - logprobs = None if logits_processor: # requires logits_processor support generate_params['logits_processor'] = LogitsProcessorList(logits_processor) @@ -137,12 +241,14 @@ def convert_history(history): user_input = "" user_input_last = True system_message = "" + seen_non_system = False for entry in history: content = entry["content"] role = entry["role"] if role == "user": + seen_non_system = True # Extract text content (images handled by model-specific code) content = process_multimodal_content(content) user_input = content @@ -154,6 +260,7 @@ def convert_history(history): current_message = content elif role == "assistant": + seen_non_system = True meta = {} tool_calls = entry.get("tool_calls") if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0: @@ -170,13 +277,22 @@ def convert_history(history): else: chat_dialogue.append(['', current_reply, '', meta]) 
elif role == "tool": + seen_non_system = True user_input_last = False meta = {} if "tool_call_id" in entry: meta["tool_call_id"] = entry["tool_call_id"] chat_dialogue.append(['', '', content, meta]) - elif role == "system": - system_message += f"\n{content}" if system_message else content + elif role in ("system", "developer"): + if not seen_non_system: + # Leading system messages go to custom_system_message (placed at top) + system_message += f"\n{content}" if system_message else content + else: + # Mid-conversation system messages: preserve position in history + if current_message: + chat_dialogue.append([current_message, '', '', {}]) + current_message = "" + chat_dialogue.append([content, '', '', {"role": "system"}]) if not user_input_last: user_input = "" @@ -202,6 +318,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0: tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails + tool_choice = body.get('tool_choice', None) + if tool_choice == "none": + tools = None # Disable tool detection entirely + messages = body['messages'] for m in messages: if 'role' not in m: @@ -293,29 +413,55 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p requested_model = generate_params.pop('model') logprob_proc = generate_params.pop('logprob_proc', None) + if logprob_proc: + logprob_proc.token_alternatives_history.clear() + chat_logprobs_offset = [0] # mutable for closure access in streaming - def chat_streaming_chunk(content, chunk_tool_calls=None): + def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False, reasoning_content=None): # begin streaming + delta = {} + if include_role: + delta['role'] = 'assistant' + delta['refusal'] = None + if content is not None: + delta['content'] = content + if reasoning_content is not None: + 
delta['reasoning_content'] = reasoning_content + if chunk_tool_calls: + delta['tool_calls'] = chunk_tool_calls + chunk = { "id": cmpl_id, "object": object_type, "created": created_time, "model": shared.model_name, + "system_fingerprint": None, resp_list: [{ "index": 0, "finish_reason": None, - "delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls}, + "delta": delta, + "logprobs": None, }], } - if logprob_proc: # not official for chat yet - top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives) - chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]} - # else: - # chunk[resp_list][0]["logprobs"] = None + if logprob_proc: + entries = _dict_to_logprob_entries(logprob_proc.token_alternatives) + formatted = format_chat_logprobs(entries) + if formatted: + chunk[resp_list][0]["logprobs"] = formatted + elif shared.args.loader in ('llama.cpp', 'ExLlamav3'): + entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0]) + if entries: + formatted = format_chat_logprobs(entries) + if formatted: + chunk[resp_list][0]["logprobs"] = formatted return chunk + # Check if usage should be included in streaming chunks per OpenAI spec + stream_options = body.get('stream_options') + include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False)) + # generate reply ####################################### if prompt_only: prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_) @@ -323,75 +469,133 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p return if stream: - yield chat_streaming_chunk('') + chunk = chat_streaming_chunk('', include_role=True) + if include_usage: + chunk['usage'] = None + yield chunk generator = generate_chat_reply( user_input, generate_params, regenerate=False, _continue=continue_, 
loading_message=False) answer = '' seen_content = '' + seen_reasoning = '' tool_calls = [] end_last_tool_call = 0 supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None + _tool_parsers = None + + # Filter supported_tools when tool_choice specifies a particular function + if supported_tools and isinstance(tool_choice, dict): + specified_func = tool_choice.get("function", {}).get("name") + if specified_func and specified_func in supported_tools: + supported_tools = [specified_func] + + if supported_tools is not None: + _template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '') + _tool_parsers, _, _ = detect_tool_call_format(_template_str) for a in generator: answer = a['internal'][-1][1] if supported_tools is not None: - tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else [] + tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else [] if len(tool_call) > 0: for tc in tool_call: - tc["id"] = getToolCallId() - tc["index"] = len(tool_calls) + tc["id"] = get_tool_call_id() + if stream: + tc["index"] = len(tool_calls) tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"]) tool_calls.append(tc) end_last_tool_call = len(answer) - if stream: - len_seen = len(seen_content) - new_content = answer[len_seen:] - - if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. - continue - - chunk = chat_streaming_chunk(new_content) - - seen_content = answer - yield chunk - - # stop generation if tool_calls were generated previously + # Stop generation before streaming content if tool_calls were detected, + # so that raw tool markup is not sent as content deltas. if len(tool_calls) > 0: break + if stream: + # Strip reasoning/thinking blocks so only final content is streamed. 
+ # Reasoning is emitted separately as reasoning_content deltas. + reasoning, content = extract_reasoning(answer) + if reasoning is not None: + new_reasoning = reasoning[len(seen_reasoning):] + new_content = content[len(seen_content):] + else: + new_reasoning = None + new_content = answer[len(seen_content):] + + if (not new_content and not new_reasoning) or chr(0xfffd) in (new_content or '') + (new_reasoning or ''): + continue + + chunk = chat_streaming_chunk( + content=new_content if new_content else None, + reasoning_content=new_reasoning if new_reasoning else None, + ) + if include_usage: + chunk['usage'] = None + + if reasoning is not None: + seen_reasoning = reasoning + seen_content = content + else: + seen_content = answer + yield chunk + token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0 completion_token_count = len(encode(answer)[0]) - stop_reason = "stop" if len(tool_calls) > 0: stop_reason = "tool_calls" - if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']: + elif token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']: stop_reason = "length" + else: + stop_reason = "stop" if stream: - chunk = chat_streaming_chunk('', tool_calls) + chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls) chunk[resp_list][0]['finish_reason'] = stop_reason - chunk['usage'] = { + usage = { "prompt_tokens": token_count, "completion_tokens": completion_token_count, "total_tokens": token_count + completion_token_count } - yield chunk + if include_usage: + chunk['usage'] = None + yield chunk + # Separate usage-only chunk with choices: [] per OpenAI spec + yield { + "id": cmpl_id, + "object": object_type, + "created": created_time, + "model": shared.model_name, + "system_fingerprint": None, + resp_list: [], + "usage": usage + } + else: + yield 
chunk else: + reasoning, content = extract_reasoning(answer) + message = { + "role": "assistant", + "refusal": None, + "content": None if tool_calls else content, + **({"reasoning_content": reasoning} if reasoning else {}), + **({"tool_calls": tool_calls} if tool_calls else {}), + } resp = { "id": cmpl_id, "object": object_type, "created": created_time, "model": shared.model_name, + "system_fingerprint": None, resp_list: [{ "index": 0, "finish_reason": stop_reason, - "message": {"role": "assistant", "content": answer, **({"tool_calls": tool_calls} if tool_calls else {})}, + "message": message, + "logprobs": None, }], "usage": { "prompt_tokens": token_count, @@ -399,11 +603,19 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p "total_tokens": token_count + completion_token_count } } - if logprob_proc: # not official for chat yet - top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives) - resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]} - # else: - # resp[resp_list][0]["logprobs"] = None + if logprob_proc: + all_entries = [] + for alt in logprob_proc.token_alternatives_history: + all_entries.extend(_dict_to_logprob_entries(alt)) + formatted = format_chat_logprobs(all_entries) + if formatted: + resp[resp_list][0]["logprobs"] = formatted + elif shared.args.loader in ('llama.cpp', 'ExLlamav3'): + raw = getattr(shared.model, 'last_completion_probabilities', None) + if raw: + formatted = format_chat_logprobs(raw) + if formatted: + resp[resp_list][0]["logprobs"] = formatted yield resp @@ -411,7 +623,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None): object_type = 'text_completion' created_time = int(time.time()) - cmpl_id = "conv-%d" % (int(time.time() * 1000000000)) + cmpl_id = "cmpl-%d" % (int(time.time() * 1000000000)) resp_list = 'data' if 
is_legacy else 'choices' prompt_str = 'context' if is_legacy else 'prompt' @@ -445,6 +657,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e generate_params['stop_event'] = stop_event requested_model = generate_params.pop('model') logprob_proc = generate_params.pop('logprob_proc', None) + if logprob_proc: + logprob_proc.token_alternatives_history.clear() suffix = body['suffix'] if body['suffix'] else '' echo = body['echo'] @@ -456,6 +670,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e logger.info(f"Found {len(raw_images)} image(s) in request.") generate_params['raw_images'] = raw_images + n_completions = body.get('n', 1) or 1 + if not stream: prompt_arg = body[prompt_str] @@ -469,6 +685,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e resp_list_data = [] total_completion_token_count = 0 total_prompt_token_count = 0 + choice_index = 0 for idx, prompt in enumerate(prompt_arg, start=0): if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int): @@ -483,37 +700,59 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e prompt = decode(prompt)[0] prefix = prompt if echo else '' - - # generate reply ####################################### - debug_msg({'prompt': prompt, 'generate_params': generate_params}) - generator = generate_reply(prompt, generate_params, is_chat=False) - answer = '' - - for a in generator: - answer = a - token_count = len(encode(prompt)[0]) total_prompt_token_count += token_count - completion_token_count = len(encode(answer)[0]) - total_completion_token_count += completion_token_count - stop_reason = "stop" - if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens: - stop_reason = "length" - respi = { - "index": idx, - "finish_reason": stop_reason, - "text": prefix + answer + suffix, - "logprobs": {'top_logprobs': 
[logprob_proc.token_alternatives]} if logprob_proc else None, - } + original_seed = generate_params.get('seed', -1) + for _n in range(n_completions): + # Increment seed for each completion to ensure diversity (matches llama.cpp native behavior) + if original_seed >= 0: + generate_params['seed'] = original_seed + _n - resp_list_data.extend([respi]) + if logprob_proc: + logprob_proc.token_alternatives_history.clear() + + # generate reply ####################################### + debug_msg({'prompt': prompt, 'generate_params': generate_params}) + generator = generate_reply(prompt, generate_params, is_chat=False) + answer = '' + + for a in generator: + answer = a + + completion_token_count = len(encode(answer)[0]) + total_completion_token_count += completion_token_count + stop_reason = "stop" + if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens: + stop_reason = "length" + + if logprob_proc: + all_entries = [] + for alt in logprob_proc.token_alternatives_history: + all_entries.extend(_dict_to_logprob_entries(alt)) + completion_logprobs = format_completion_logprobs(all_entries) + elif shared.args.loader in ('llama.cpp', 'ExLlamav3'): + raw = getattr(shared.model, 'last_completion_probabilities', None) + completion_logprobs = format_completion_logprobs(raw) + else: + completion_logprobs = None + + respi = { + "index": choice_index, + "finish_reason": stop_reason, + "text": prefix + answer + suffix, + "logprobs": completion_logprobs, + } + + resp_list_data.append(respi) + choice_index += 1 resp = { "id": cmpl_id, "object": object_type, "created": created_time, "model": shared.model_name, + "system_fingerprint": None, resp_list: resp_list_data, "usage": { "prompt_tokens": total_prompt_token_count, @@ -538,24 +777,41 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e prefix = prompt if echo else '' token_count = len(encode(prompt)[0]) + # Check if usage should be included in 
streaming chunks per OpenAI spec + stream_options = body.get('stream_options') + include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False)) + cmpl_logprobs_offset = [0] # mutable for closure access in streaming + def text_streaming_chunk(content): # begin streaming + if logprob_proc: + chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives)) + elif shared.args.loader in ('llama.cpp', 'ExLlamav3'): + entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0]) + chunk_logprobs = format_completion_logprobs(entries) if entries else None + else: + chunk_logprobs = None + chunk = { "id": cmpl_id, "object": object_type, "created": created_time, "model": shared.model_name, + "system_fingerprint": None, resp_list: [{ "index": 0, "finish_reason": None, "text": content, - "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None, + "logprobs": chunk_logprobs, }], } return chunk - yield text_streaming_chunk(prefix) + chunk = text_streaming_chunk(prefix) + if include_usage: + chunk['usage'] = None + yield chunk # generate reply ####################################### debug_msg({'prompt': prompt, 'generate_params': generate_params}) @@ -575,6 +831,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e seen_content = answer chunk = text_streaming_chunk(new_content) + if include_usage: + chunk['usage'] = None yield chunk completion_token_count = len(encode(answer)[0]) @@ -584,13 +842,27 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e chunk = text_streaming_chunk(suffix) chunk[resp_list][0]["finish_reason"] = stop_reason - chunk["usage"] = { + usage = { "prompt_tokens": token_count, "completion_tokens": completion_token_count, "total_tokens": token_count + completion_token_count } - yield chunk + if 
include_usage: + chunk['usage'] = None + yield chunk + # Separate usage-only chunk with choices: [] per OpenAI spec + yield { + "id": cmpl_id, + "object": object_type, + "created": created_time, + "model": shared.model_name, + "system_fingerprint": None, + resp_list: [], + "usage": usage + } + else: + yield chunk def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict: diff --git a/extensions/openai/models.py b/extensions/openai/models.py index d6ef119d..c879a860 100644 --- a/extensions/openai/models.py +++ b/extensions/openai/models.py @@ -1,4 +1,4 @@ -from modules import shared, ui +from modules import loaders, shared from modules.logging_colors import logger from modules.LoRA import add_lora_to_model from modules.models import load_model, unload_model @@ -20,10 +20,14 @@ def list_models(): def list_models_openai_format(): """Returns model list in OpenAI API format""" - model_names = get_available_models() + if shared.model_name and shared.model_name != 'None': + data = [model_info_dict(shared.model_name)] + else: + data = [] + return { "object": "list", - "data": [model_info_dict(name) for name in model_names] + "data": data } @@ -50,7 +54,7 @@ def _load_model(data): # parameters exposed in the UI. Never allow security-sensitive # flags like trust_remote_code or extra_flags to be set via the API. 
blocked_keys = {'extra_flags'} - allowed_keys = set(ui.list_model_elements()) - blocked_keys + allowed_keys = set(loaders.list_model_elements()) - blocked_keys if args: for k in args: if k in allowed_keys and hasattr(shared.args, k): diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 7a13638d..a0d5deb8 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -21,6 +21,7 @@ import extensions.openai.completions as OAIcompletions import extensions.openai.logits as OAIlogits import extensions.openai.models as OAImodels from extensions.openai.tokens import token_count, token_decode, token_encode +from extensions.openai.errors import OpenAIError from extensions.openai.utils import _start_cloudflared from modules import shared from modules.logging_colors import logger @@ -94,6 +95,20 @@ app.add_middleware( ) +@app.exception_handler(OpenAIError) +async def openai_error_handler(request: Request, exc: OpenAIError): + error_type = "server_error" if exc.code >= 500 else "invalid_request_error" + return JSONResponse( + status_code=exc.code, + content={"error": { + "message": exc.message, + "type": error_type, + "param": getattr(exc, 'param', None), + "code": None + }} + ) + + @app.middleware("http") async def validate_host_header(request: Request, call_next): # Be strict about only approving access to localhost by default @@ -119,6 +134,12 @@ async def openai_completions(request: Request, request_data: CompletionRequest): is_legacy = "/generate" in path if request_data.stream: + if (request_data.n or 1) > 1: + return JSONResponse( + status_code=400, + content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}} + ) + stop_event = threading.Event() async def generator(): @@ -130,6 +151,8 @@ async def openai_completions(request: Request, request_data: CompletionRequest): break yield {"data": json.dumps(resp)} + + yield {"data": "[DONE]"} finally: 
stop_event.set() response.close() @@ -170,6 +193,8 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion break yield {"data": json.dumps(resp)} + + yield {"data": "[DONE]"} finally: stop_event.set() response.close() @@ -433,10 +458,13 @@ def run_server(): # In the server configuration: server_addrs = [] - if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6): - server_addrs.append('[::]' if shared.args.listen else '[::1]') - if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4): - server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1') + if shared.args.listen and shared.args.listen_host: + server_addrs.append(shared.args.listen_host) + else: + if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6): + server_addrs.append('[::]' if shared.args.listen else '[::1]') + if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4): + server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1') if not server_addrs: raise Exception('you MUST enable IPv6 or IPv4 for the API to work') @@ -447,11 +475,11 @@ def run_server(): port, shared.args.public_api_id, max_attempts=3, - on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}\n') + on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}/v1\n') ) else: url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://' - urls = [f'{url_proto}{addr}:{port}' for addr in server_addrs] + urls = [f'{url_proto}{addr}:{port}/v1' for addr in server_addrs] if len(urls) > 1: logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n') else: diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index e48b7b60..80831c44 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -99,6 +99,10 @@ class ToolCall(BaseModel): function: FunctionCall +class StreamOptions(BaseModel): + include_usage: bool | None = False + + class 
CompletionRequestParams(BaseModel): model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.") prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.") @@ -109,10 +113,11 @@ class CompletionRequestParams(BaseModel): logit_bias: dict | None = None logprobs: int | None = None max_tokens: int | None = 512 - n: int | None = Field(default=1, description="Unused parameter.") + n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.") presence_penalty: float | None = shared.args.presence_penalty stop: str | List[str] | None = None stream: bool | None = False + stream_options: StreamOptions | None = None suffix: str | None = None temperature: float | None = shared.args.temperature top_p: float | None = shared.args.top_p @@ -145,16 +150,27 @@ class ChatCompletionRequestParams(BaseModel): function_call: str | dict | None = Field(default=None, description="Unused parameter.") functions: List[dict] | None = Field(default=None, description="Unused parameter.") tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.") + tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.") logit_bias: dict | None = None + logprobs: bool | None = None + top_logprobs: int | None = None max_tokens: int | None = None + max_completion_tokens: int | None = None n: int | None = Field(default=1, description="Unused parameter.") presence_penalty: float | None = shared.args.presence_penalty stop: str | List[str] | None = None stream: bool | None = False + stream_options: StreamOptions | None = None temperature: float | None = shared.args.temperature top_p: float | None = shared.args.top_p user: str | None = Field(default=None, 
description="Unused parameter.") + @model_validator(mode='after') + def resolve_max_tokens(self): + if self.max_tokens is None and self.max_completion_tokens is not None: + self.max_tokens = self.max_completion_tokens + return self + mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.") instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.") diff --git a/extensions/openai/utils.py b/extensions/openai/utils.py index f4a31d1a..2b414769 100644 --- a/extensions/openai/utils.py +++ b/extensions/openai/utils.py @@ -1,8 +1,5 @@ import base64 -import json import os -import random -import re import time import traceback from typing import Callable, Optional @@ -55,473 +52,3 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star time.sleep(3) raise Exception('Could not start cloudflared.') - - -def getToolCallId() -> str: - letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789" - b = [random.choice(letter_bytes) for _ in range(8)] - return "call_" + "".join(b).lower() - - -def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]): - # check if property 'function' exists and is a dictionary, otherwise adapt dict - if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str): - candidate_dict = {"type": "function", "function": candidate_dict} - if 'function' in candidate_dict and isinstance(candidate_dict['function'], str): - candidate_dict['name'] = candidate_dict['function'] - del candidate_dict['function'] - candidate_dict = {"type": "function", "function": candidate_dict} - if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict): - # check if 'name' exists within 'function' and is part of known tools - if 'name' 
in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names: - candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value - # map property 'parameters' used by some older models to 'arguments' - if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]: - candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"] - del candidate_dict["function"]["parameters"] - return candidate_dict - return None - - -def _extractBalancedJson(text: str, start: int) -> str | None: - """Extract a balanced JSON object from text starting at the given position. - - Walks through the string tracking brace depth and string boundaries - to correctly handle arbitrary nesting levels. - """ - if start >= len(text) or text[start] != '{': - return None - depth = 0 - in_string = False - escape_next = False - for i in range(start, len(text)): - c = text[i] - if escape_next: - escape_next = False - continue - if c == '\\' and in_string: - escape_next = True - continue - if c == '"': - in_string = not in_string - continue - if in_string: - continue - if c == '{': - depth += 1 - elif c == '}': - depth -= 1 - if depth == 0: - return text[start:i + 1] - return None - - -def _parseChannelToolCalls(answer: str, tool_names: list[str]): - """Parse channel-based tool calls used by GPT-OSS and similar models. 
- - Format: - <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"} - """ - matches = [] - for m in re.finditer( - r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>', - answer - ): - func_name = m.group(1).strip() - if func_name not in tool_names: - continue - json_str = _extractBalancedJson(answer, m.end()) - if json_str is None: - continue - try: - arguments = json.loads(json_str) - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass - return matches - - -def _parseBareNameToolCalls(answer: str, tool_names: list[str]): - """Parse bare function-name style tool calls used by Mistral and similar models. - - Format: - functionName{"arg": "value"} - Multiple calls are concatenated directly or separated by whitespace. - """ - matches = [] - # Match tool name followed by opening brace, then extract balanced JSON - escaped_names = [re.escape(name) for name in tool_names] - pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{' - for match in re.finditer(pattern, answer): - text = match.group(0) - name = None - for n in tool_names: - if text.startswith(n): - name = n - break - if not name: - continue - brace_start = match.end() - 1 - json_str = _extractBalancedJson(answer, brace_start) - if json_str is None: - continue - try: - arguments = json.loads(json_str) - matches.append({ - "type": "function", - "function": { - "name": name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass - return matches - - -def _parseXmlParamToolCalls(answer: str, tool_names: list[str]): - """Parse XML-parameter style tool calls used by Qwen3.5 and similar models. 
- - Format: - - - value - - - """ - matches = [] - for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): - tc_content = tc_match.group(1) - func_match = re.search(r']+)>', tc_content) - if not func_match: - continue - func_name = func_match.group(1).strip() - if func_name not in tool_names: - continue - arguments = {} - for param_match in re.finditer(r']+)>\s*(.*?)\s*', tc_content, re.DOTALL): - param_name = param_match.group(1).strip() - param_value = param_match.group(2).strip() - try: - param_value = json.loads(param_value) - except (json.JSONDecodeError, ValueError): - pass # keep as string - arguments[param_name] = param_value - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - return matches - - -def _parseKimiToolCalls(answer: str, tool_names: list[str]): - """Parse Kimi-K2-style tool calls using pipe-delimited tokens. - - Format: - <|tool_calls_section_begin|> - <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|> - <|tool_calls_section_end|> - """ - matches = [] - for m in re.finditer( - r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*', - answer - ): - func_name = m.group(1).strip() - if func_name not in tool_names: - continue - json_str = _extractBalancedJson(answer, m.end()) - if json_str is None: - continue - try: - arguments = json.loads(json_str) - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass - return matches - - -def _parseMiniMaxToolCalls(answer: str, tool_names: list[str]): - """Parse MiniMax-style tool calls using invoke/parameter XML tags. 
- - Format: - - - value - - - """ - matches = [] - for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): - tc_content = tc_match.group(1) - # Split on to handle multiple parallel calls in one block - for invoke_match in re.finditer(r'(.*?)', tc_content, re.DOTALL): - func_name = invoke_match.group(1).strip() - if func_name not in tool_names: - continue - invoke_body = invoke_match.group(2) - arguments = {} - for param_match in re.finditer(r'\s*(.*?)\s*', invoke_body, re.DOTALL): - param_name = param_match.group(1).strip() - param_value = param_match.group(2).strip() - try: - param_value = json.loads(param_value) - except (json.JSONDecodeError, ValueError): - pass # keep as string - arguments[param_name] = param_value - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - return matches - - -def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]): - """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters. - - Format: - <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|> - """ - matches = [] - for m in re.finditer( - r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*', - answer - ): - func_name = m.group(1).strip() - if func_name not in tool_names: - continue - json_str = _extractBalancedJson(answer, m.end()) - if json_str is None: - continue - try: - arguments = json.loads(json_str) - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass - return matches - - -def _parseGlmToolCalls(answer: str, tool_names: list[str]): - """Parse GLM-style tool calls using arg_key/arg_value XML pairs. 
- - Format: - function_name - key1 - value1 - - """ - matches = [] - for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): - tc_content = tc_match.group(1) - # First non-tag text is the function name - name_match = re.match(r'([^<\s]+)', tc_content.strip()) - if not name_match: - continue - func_name = name_match.group(1).strip() - if func_name not in tool_names: - continue - # Extract arg_key/arg_value pairs - keys = [k.group(1).strip() for k in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] - vals = [v.group(1).strip() for v in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] - if len(keys) != len(vals): - continue - arguments = {} - for k, v in zip(keys, vals): - try: - v = json.loads(v) - except (json.JSONDecodeError, ValueError): - pass # keep as string - arguments[k] = v - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - return matches - - -def _parsePythonicToolCalls(answer: str, tool_names: list[str]): - """Parse pythonic-style tool calls used by Llama 4 and similar models. 
- - Format: - [func_name(param1="value1", param2="value2"), func_name2(...)] - """ - matches = [] - # Match a bracketed list of function calls - bracket_match = re.search(r'\[([^\[\]]+)\]', answer) - if not bracket_match: - return matches - - inner = bracket_match.group(1) - - # Build pattern for known tool names - escaped_names = [re.escape(name) for name in tool_names] - name_pattern = '|'.join(escaped_names) - - for call_match in re.finditer( - r'(' + name_pattern + r')\(([^)]*)\)', - inner - ): - func_name = call_match.group(1) - params_str = call_match.group(2).strip() - arguments = {} - - if params_str: - # Parse key="value" pairs, handling commas inside quoted values - for param_match in re.finditer( - r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)', - params_str - ): - param_name = param_match.group(1) - param_value = param_match.group(2).strip() - # Strip surrounding quotes - if (param_value.startswith('"') and param_value.endswith('"')) or \ - (param_value.startswith("'") and param_value.endswith("'")): - param_value = param_value[1:-1] - # Try to parse as JSON for numeric/bool/null values - try: - param_value = json.loads(param_value) - except (json.JSONDecodeError, ValueError): - pass - arguments[param_name] = param_value - - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - - return matches - - -def parseToolCall(answer: str, tool_names: list[str]): - matches = [] - - # abort on very short answers to save computation cycles - if len(answer) < 10: - return matches - - # Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters) - matches = _parseDeepSeekToolCalls(answer, tool_names) - if matches: - return matches - - # Check for Kimi-K2-style tool calls (pipe-delimited tokens) - matches = _parseKimiToolCalls(answer, tool_names) - if matches: - return matches - - # Check for channel-based tool calls (e.g. 
GPT-OSS format) - matches = _parseChannelToolCalls(answer, tool_names) - if matches: - return matches - - # Check for MiniMax-style tool calls (invoke/parameter XML tags) - matches = _parseMiniMaxToolCalls(answer, tool_names) - if matches: - return matches - - # Check for GLM-style tool calls (arg_key/arg_value XML pairs) - matches = _parseGlmToolCalls(answer, tool_names) - if matches: - return matches - - # Check for XML-parameter style tool calls (e.g. Qwen3.5 format) - matches = _parseXmlParamToolCalls(answer, tool_names) - if matches: - return matches - - # Check for bare function-name style tool calls (e.g. Mistral format) - matches = _parseBareNameToolCalls(answer, tool_names) - if matches: - return matches - - # Check for pythonic-style tool calls (e.g. Llama 4 format) - matches = _parsePythonicToolCalls(answer, tool_names) - if matches: - return matches - - # Define the regex pattern to find the JSON content wrapped in , , , and other tags observed from various models - patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)"] - - for pattern in patterns: - for match in re.finditer(pattern, answer, re.DOTALL): - # print(match.group(2)) - if match.group(2) is None: - continue - # remove backtick wraps if present - candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip()) - candidate = re.sub(r"```$", "", candidate.strip()) - # unwrap inner tags - candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL) - # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually - if re.search(r"\}\s*\n\s*\{", candidate) is not None: - candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate) - if not candidate.strip().startswith("["): - candidate = "[" + candidate + "]" - - candidates = [] - try: - # parse the candidate JSON into a dictionary - candidates = json.loads(candidate) - if not isinstance(candidates, list): - candidates = [candidates] - except 
json.JSONDecodeError: - # Ignore invalid JSON silently - continue - - for candidate_dict in candidates: - checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names) - if checked_candidate is not None: - matches.append(checked_candidate) - - # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags - if len(matches) == 0: - try: - candidate = answer - # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually - if re.search(r"\}\s*\n\s*\{", candidate) is not None: - candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate) - if not candidate.strip().startswith("["): - candidate = "[" + candidate + "]" - # parse the candidate JSON into a dictionary - candidates = json.loads(candidate) - if not isinstance(candidates, list): - candidates = [candidates] - for candidate_dict in candidates: - checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names) - if checked_candidate is not None: - matches.append(checked_candidate) - except json.JSONDecodeError: - # Ignore invalid JSON silently - pass - - return matches diff --git a/js/global_scope_js.js b/js/global_scope_js.js index 62b31d37..92f65622 100644 --- a/js/global_scope_js.js +++ b/js/global_scope_js.js @@ -269,7 +269,49 @@ function removeLastClick() { document.getElementById("Remove-last").click(); } +function autoScrollToBottom() { + if (!window.isScrolled) { + const chatParent = document.getElementById("chat")?.parentNode?.parentNode?.parentNode; + if (chatParent) { + const maxScroll = chatParent.scrollHeight - chatParent.clientHeight; + if (maxScroll > 0 && chatParent.scrollTop < maxScroll - 1) { + chatParent.scrollTop = maxScroll; + } + } + } +} + +function updateInstructPadding() { + const chatElement = document.getElementById("chat"); + if (chatElement && chatElement.getAttribute("data-mode") === "instruct") { + const messagesContainer = 
chatElement.querySelector(".messages"); + const lastChild = messagesContainer?.lastElementChild; + const prevSibling = lastChild?.previousElementSibling; + if (lastChild && prevSibling && chatElement.offsetHeight > 0) { + let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight); + if (window.innerWidth <= 924) { + bufferHeight = Math.max(0, bufferHeight - 32); + } + messagesContainer.style.paddingBottom = `${bufferHeight}px`; + } + } +} + +let pendingMorphdomData = null; +let morphdomRafId = null; + function handleMorphdomUpdate(data) { + pendingMorphdomData = data; + if (!morphdomRafId) { + morphdomRafId = requestAnimationFrame(() => { + morphdomRafId = null; + applyMorphdomUpdate(pendingMorphdomData); + pendingMorphdomData = null; + }); + } +} + +function applyMorphdomUpdate(data) { // Determine target element and use it as query scope var target_element, target_html; if (data.last_message_only) { @@ -283,27 +325,21 @@ function handleMorphdomUpdate(data) { const queryScope = target_element; - // Track open blocks + // Track open blocks and store their scroll positions const openBlocks = new Set(); + const scrollPositions = {}; queryScope.querySelectorAll(".thinking-block").forEach(block => { const blockId = block.getAttribute("data-block-id"); - // If block exists and is open, add to open set if (blockId && block.hasAttribute("open")) { openBlocks.add(blockId); - } - }); - - // Store scroll positions for any open blocks - const scrollPositions = {}; - queryScope.querySelectorAll(".thinking-block[open]").forEach(block => { - const content = block.querySelector(".thinking-content"); - const blockId = block.getAttribute("data-block-id"); - if (content && blockId) { - const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5; - scrollPositions[blockId] = { - position: content.scrollTop, - isAtBottom: isAtBottom - }; + const content = 
block.querySelector(".thinking-content"); + if (content) { + const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5; + scrollPositions[blockId] = { + position: content.scrollTop, + isAtBottom: isAtBottom + }; + } } }); @@ -313,8 +349,8 @@ function handleMorphdomUpdate(data) { { onBeforeElUpdated: function(fromEl, toEl) { // Preserve code highlighting - if (fromEl.tagName === "PRE" && fromEl.querySelector("code[data-highlighted]")) { - const fromCode = fromEl.querySelector("code"); + if (fromEl.tagName === "PRE") { + const fromCode = fromEl.querySelector("code[data-highlighted]"); const toCode = toEl.querySelector("code"); if (fromCode && toCode && fromCode.textContent === toCode.textContent) { @@ -359,10 +395,23 @@ function handleMorphdomUpdate(data) { } ); + // Syntax highlighting and LaTeX + if (window.doSyntaxHighlighting) { + window.doSyntaxHighlighting(); + } + + // Auto-scroll runs both before and after padding update. + // Before: so content growth isn't hidden by padding absorption. + // After: so padding-added space is also scrolled into view. 
+ autoScrollToBottom(); + updateInstructPadding(); + autoScrollToBottom(); + // Add toggle listeners for new blocks queryScope.querySelectorAll(".thinking-block").forEach(block => { if (!block._hasToggleListener) { block.addEventListener("toggle", function(e) { + const wasScrolled = window.isScrolled; if (this.open) { const content = this.querySelector(".thinking-content"); if (content) { @@ -371,6 +420,12 @@ function handleMorphdomUpdate(data) { }, 0); } } + autoScrollToBottom(); + updateInstructPadding(); + autoScrollToBottom(); + // Restore scroll state so the browser's layout adjustment + // from the toggle doesn't disable auto-scroll + window.isScrolled = wasScrolled; }); block._hasToggleListener = true; } diff --git a/js/main.js b/js/main.js index 1317e9e7..f05f93c6 100644 --- a/js/main.js +++ b/js/main.js @@ -2,6 +2,12 @@ // Main // ------------------------------------------------ +// Sync highlight.js theme with the actual Gradio theme +var defined_hljs_css = document.body.classList.contains("dark") ? 
"file/css/highlightjs/github-dark.min.css" : "file/css/highlightjs/github.min.css"; +if (document.getElementById("highlight-css").getAttribute("href") !== defined_hljs_css) { + document.getElementById("highlight-css").setAttribute("href", defined_hljs_css); +} + let main_parent = document.getElementById("chat-tab").parentNode; let extensions = document.getElementById("extensions"); @@ -145,10 +151,13 @@ targetElement.classList.add("pretty_scrollbar"); targetElement.classList.add("chat-parent"); window.isScrolled = false; let scrollTimeout; +let lastScrollTop = 0; +let lastScrollHeight = 0; +let lastClientHeight = 0; targetElement.addEventListener("scroll", function() { let diff = targetElement.scrollHeight - targetElement.clientHeight; - let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff == 0; + let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0; // Add scrolling class to disable hover effects if (window.isScrolled || !isAtBottomNow) { @@ -157,9 +166,12 @@ targetElement.addEventListener("scroll", function() { if(isAtBottomNow) { window.isScrolled = false; - } else { + } else if (targetElement.scrollTop < lastScrollTop && targetElement.scrollHeight >= lastScrollHeight && targetElement.clientHeight <= lastClientHeight) { window.isScrolled = true; } + lastScrollTop = targetElement.scrollTop; + lastScrollHeight = targetElement.scrollHeight; + lastClientHeight = targetElement.clientHeight; // Clear previous timeout and set new one clearTimeout(scrollTimeout); @@ -170,61 +182,28 @@ targetElement.addEventListener("scroll", function() { }); // Create a MutationObserver instance -const observer = new MutationObserver(function(mutations) { - // Check if this is just the scrolling class being toggled - const isScrollingClassOnly = mutations.every(mutation => - mutation.type === "attributes" && - mutation.attributeName === "class" && - mutation.target === targetElement - ); - +const observer = new MutationObserver(function() 
{ if (targetElement.classList.contains("_generating")) { typing.parentNode.classList.add("visible-dots"); document.getElementById("stop").style.display = "flex"; document.getElementById("Generate").style.display = "none"; + // If the user is near the bottom, ensure auto-scroll is enabled + // for the new reply. This catches cases where isScrolled was + // incorrectly set to true by layout shifts during page load, etc. + const diff = targetElement.scrollHeight - targetElement.clientHeight; + if (Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0) { + window.isScrolled = false; + } } else { typing.parentNode.classList.remove("visible-dots"); document.getElementById("stop").style.display = "none"; document.getElementById("Generate").style.display = "flex"; } - - doSyntaxHighlighting(); - - if (!window.isScrolled && !isScrollingClassOnly) { - const maxScroll = targetElement.scrollHeight - targetElement.clientHeight; - if (maxScroll > 0 && targetElement.scrollTop < maxScroll - 1) { - targetElement.scrollTop = maxScroll; - } - } - - const chatElement = document.getElementById("chat"); - if (chatElement && chatElement.getAttribute("data-mode") === "instruct") { - const messagesContainer = chatElement.querySelector(".messages"); - const lastChild = messagesContainer?.lastElementChild; - const prevSibling = lastChild?.previousElementSibling; - if (lastChild && prevSibling) { - // Add padding to the messages container to create room for the last message. - // The purpose of this is to avoid constant scrolling during streaming in - // instruct mode. 
- let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight); - - // Subtract header height when screen width is <= 924px - if (window.innerWidth <= 924) { - bufferHeight = Math.max(0, bufferHeight - 32); - } - - messagesContainer.style.paddingBottom = `${bufferHeight}px`; - } - } }); -// Configure the observer to watch for changes in the subtree and attributes +// Only watch for attribute changes on targetElement (e.g. _generating class) const config = { - childList: true, - subtree: true, - characterData: true, - attributeOldValue: true, - characterDataOldValue: true + attributes: true }; // Start observing the target element @@ -243,64 +222,76 @@ function isElementVisibleOnScreen(element) { ); } -function doSyntaxHighlighting() { +window.doSyntaxHighlighting = function() { const messageBodies = document.getElementById("chat").querySelectorAll(".message-body"); if (messageBodies.length > 0) { - observer.disconnect(); + let hasSeenVisible = false; - try { - let hasSeenVisible = false; + // Go from last message to first + for (let i = messageBodies.length - 1; i >= 0; i--) { + const messageBody = messageBodies[i]; - // Go from last message to first - for (let i = messageBodies.length - 1; i >= 0; i--) { - const messageBody = messageBodies[i]; + if (isElementVisibleOnScreen(messageBody)) { + hasSeenVisible = true; - if (isElementVisibleOnScreen(messageBody)) { - hasSeenVisible = true; + // Handle both code and math in a single pass through each message + const codeBlocks = messageBody.querySelectorAll("pre code:not([data-highlighted])"); + codeBlocks.forEach((codeBlock) => { + hljs.highlightElement(codeBlock); + codeBlock.setAttribute("data-highlighted", "true"); + codeBlock.classList.add("pretty_scrollbar"); + }); - // Handle both code and math in a single pass through each message - const codeBlocks = messageBody.querySelectorAll("pre code:not([data-highlighted])"); - 
codeBlocks.forEach((codeBlock) => { - hljs.highlightElement(codeBlock); - codeBlock.setAttribute("data-highlighted", "true"); - codeBlock.classList.add("pretty_scrollbar"); - }); - - // Only render math in visible elements - const mathContainers = messageBody.querySelectorAll("p, span, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, figcaption, caption, dd, dt"); - mathContainers.forEach(container => { - if (isElementVisibleOnScreen(container)) { - renderMathInElement(container, { - delimiters: [ - { left: "$$", right: "$$", display: true }, - { left: "$", right: "$", display: false }, - { left: "\\(", right: "\\)", display: false }, - { left: "\\[", right: "\\]", display: true }, - ], - }); - } - }); - } else if (hasSeenVisible) { - // We've seen visible messages but this one is not visible - // Since we're going from last to first, we can break - break; - } + // Only render math in visible elements + const mathContainers = messageBody.querySelectorAll("p, span, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, figcaption, caption, dd, dt"); + mathContainers.forEach(container => { + if (isElementVisibleOnScreen(container)) { + renderMathInElement(container, { + delimiters: [ + { left: "$$", right: "$$", display: true }, + { left: "$", right: "$", display: false }, + { left: "\\(", right: "\\)", display: false }, + { left: "\\[", right: "\\]", display: true }, + ], + }); + } + }); + } else if (hasSeenVisible) { + // We've seen visible messages but this one is not visible + // Since we're going from last to first, we can break + break; } - } finally { - observer.observe(targetElement, config); } } } +const doSyntaxHighlighting = window.doSyntaxHighlighting; //------------------------------------------------ // Add some scrollbars //------------------------------------------------ -const textareaElements = document.querySelectorAll(".add_scrollbar textarea"); -for(i = 0; i < textareaElements.length; i++) { - textareaElements[i].classList.remove("scroll-hide"); - 
textareaElements[i].classList.add("pretty_scrollbar"); - textareaElements[i].style.resize = "none"; +const scrollbarElements = document.querySelectorAll(".add_scrollbar textarea, .add_scrollbar .drag-drop-list"); +for(i = 0; i < scrollbarElements.length; i++) { + scrollbarElements[i].classList.remove("scroll-hide"); + scrollbarElements[i].classList.add("pretty_scrollbar"); + scrollbarElements[i].style.resize = "none"; +} + + +//------------------------------------------------ +// Tools: inject "Refresh list" link into the label +//------------------------------------------------ +const toolsTitle = document.querySelector("#tools-group > [data-testid='block-info']"); +const toolsInfo = toolsTitle ? toolsTitle.nextElementSibling : null; +if (toolsInfo) { + const refreshLink = document.createElement("span"); + refreshLink.textContent = " [Refresh list]"; + refreshLink.className = "tools-refresh-link"; + refreshLink.addEventListener("click", function(e) { + e.preventDefault(); + document.querySelector("#tools-refresh-btn").click(); + }); + toolsInfo.appendChild(refreshLink); } //------------------------------------------------ @@ -561,6 +552,38 @@ document.querySelectorAll(".focus-on-chat-input").forEach(element => { }); }); +//------------------------------------------------ +// "New chat" hover menu with incognito option +//------------------------------------------------ + +(function() { + const newChatBtn = document.getElementById("new-chat-btn"); + + const wrapper = document.createElement("div"); + wrapper.id = "new-chat-wrapper"; + newChatBtn.replaceWith(wrapper); + wrapper.appendChild(newChatBtn); + + const arrow = document.createElement("span"); + arrow.className = "new-chat-arrow"; + arrow.textContent = "\u25BE"; + + const menu = document.createElement("div"); + menu.className = "new-chat-menu"; + const option = document.createElement("div"); + option.className = "new-chat-menu-item"; + option.textContent = "Incognito chat"; + menu.appendChild(option); + + 
arrow.appendChild(menu); + wrapper.appendChild(arrow); + + option.addEventListener("click", function(e) { + e.stopPropagation(); + document.querySelector("#incognito-chat-btn").click(); + }); +})(); + //------------------------------------------------ // Fix a border around the "past chats" menu //------------------------------------------------ @@ -1090,15 +1113,13 @@ document.fonts.addEventListener("loadingdone", (event) => { const currentHeight = chatInputRow.offsetHeight; const heightDifference = currentHeight - originalHeight; chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`; + if (!window.isScrolled) { + chatParent.scrollTop = chatParent.scrollHeight - chatParent.clientHeight; + } } - // Watch for changes that might affect height - const observer = new MutationObserver(updateMargin); - observer.observe(chatInputRow, { - childList: true, - subtree: true, - attributes: true - }); + // Watch for size changes that affect height + new ResizeObserver(updateMargin).observe(chatInputRow); // Also listen for window resize window.addEventListener("resize", updateMargin); diff --git a/modules/chat.py b/modules/chat.py index 36d373d6..e4fcaabe 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -6,12 +6,13 @@ import json import pprint import re import shutil +import threading import time from datetime import datetime from functools import partial from pathlib import Path -import gradio as gr +import markupsafe import yaml from jinja2.ext import loopcontrols from jinja2.sandbox import ImmutableSandboxedEnvironment @@ -23,10 +24,12 @@ from modules.extensions import apply_extensions from modules.html_generator import ( chat_html_wrapper, convert_to_markdown, + extract_thinking_block, make_thumbnail ) from modules.image_utils import open_image_safely from modules.logging_colors import logger +from modules.reasoning import THINKING_FORMATS from modules.text_generation import ( generate_reply, get_encoded_length, @@ -41,6 +44,8 @@ from 
modules.utils import ( ) from modules.web_search import add_web_search_attachments +_history_file_lock = threading.Lock() + def strftime_now(format): return datetime.now().strftime(format) @@ -75,8 +80,22 @@ jinja_env = ImmutableSandboxedEnvironment( lstrip_blocks=True, extensions=[loopcontrols] ) + + +def custom_tojson(value, indent=None, ensure_ascii=True): + return markupsafe.Markup(json.dumps(value, indent=indent, ensure_ascii=ensure_ascii)) + + +jinja_env.filters["tojson"] = custom_tojson jinja_env.globals["strftime_now"] = strftime_now + +def _raise_exception(message): + raise ValueError(message) + + +jinja_env.globals["raise_exception"] = _raise_exception + _template_cache = {} @@ -150,6 +169,49 @@ def _deserialize_tool_call_arguments(tool_calls): return result +def _expand_tool_sequence(tool_seq): + """Expand a tool_sequence list into API messages. + + Returns a list of dicts (role: assistant with tool_calls, or role: tool). + If any tool_call IDs are missing a matching tool result, a synthetic + empty result is inserted so the prompt is never malformed. 
+ """ + messages = [] + expected_ids = [] + seen_ids = set() + + for item in tool_seq: + if 'tool_calls' in item: + deserialized = _deserialize_tool_call_arguments(item['tool_calls']) + messages.append({ + "role": "assistant", + "content": item.get('content', ''), + "tool_calls": deserialized + }) + for tc in item['tool_calls']: + tc_id = tc.get('id', '') + if tc_id: + expected_ids.append(tc_id) + elif item.get('role') == 'tool': + messages.append({ + "role": "tool", + "content": item['content'], + "tool_call_id": item.get('tool_call_id', '') + }) + seen_ids.add(item.get('tool_call_id', '')) + + # Fill in synthetic results for any orphaned tool call IDs + for tc_id in expected_ids: + if tc_id not in seen_ids: + messages.append({ + "role": "tool", + "content": "", + "tool_call_id": tc_id + }) + + return messages + + def generate_chat_prompt(user_input, state, **kwargs): impersonate = kwargs.get('impersonate', False) _continue = kwargs.get('_continue', False) @@ -185,6 +247,7 @@ def generate_chat_prompt(user_input, state, **kwargs): name1=state['name1'], name2=state['name2'], user_bio=replace_character_names(state['user_bio'], state['name1'], state['name2']), + tools=state['tools'] if 'tools' in state else None, ) messages = [] @@ -296,7 +359,18 @@ def generate_chat_prompt(user_input, state, **kwargs): if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant': messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls']) - if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']: + # Expand tool_sequence from metadata (inserted AFTER assistant so that + # the final order is: user → tool_calls → tool_results → final_answer) + meta_key = f"assistant_{row_idx}" + tool_seq = metadata.get(meta_key, {}).get('tool_sequence', []) + if tool_seq: + for msg in reversed(_expand_tool_sequence(tool_seq)): + messages.insert(insert_pos, msg) + + if entry_meta.get('role') == 'system': + if user_msg: + messages.insert(insert_pos, 
{"role": "system", "content": user_msg}) + elif user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']: # Check for user message attachments in metadata user_key = f"user_{row_idx}" enhanced_user_msg = user_msg @@ -365,6 +439,12 @@ def generate_chat_prompt(user_input, state, **kwargs): messages.append({"role": "user", "content": user_input}) + # Expand tool_sequence for the current entry (excluded from the + # history loop during regenerate — needed so the model sees prior + # tool calls and results when re-generating the final answer). + current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', []) + messages.extend(_expand_tool_sequence(current_tool_seq)) + if impersonate and state['mode'] != 'chat-instruct': messages.append({"role": "user", "content": "fake user message replace me"}) @@ -810,6 +890,8 @@ def generate_search_query(user_message, state): query = query.rsplit("", 1)[1] elif "<|start|>assistant<|channel|>final<|message|>" in query: query = query.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1] + elif "<|channel|>final<|message|>" in query: + query = query.rsplit("<|channel|>final<|message|>", 1)[1] elif "" in query: query = query.rsplit("", 1)[1] @@ -884,7 +966,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess } else: text, visible_text = output['internal'][-1][0], output['visible'][-1][0] - if regenerate: + if regenerate and not state.get('_tool_turn'): row_idx = len(output['internal']) - 1 # Store the old response as a version before regenerating @@ -947,10 +1029,33 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess # Add timestamp for assistant's response at the start of generation update_message_metadata(output['metadata'], "assistant", row_idx, timestamp=get_current_timestamp(), model_name=shared.model_name) + # Detect if the template appended a thinking start tag to the prompt + thinking_prefix = None + if not _continue: + stripped_prompt = 
prompt.rstrip('\n') + for start_tag, end_tag, content_tag in THINKING_FORMATS: + if start_tag is not None and stripped_prompt.endswith(start_tag): + thinking_prefix = start_tag + break + + # When tools are active, buffer streaming output during potential tool + # call generation to prevent raw markup from leaking into the display. + _check_tool_markers = bool(state.get('tools')) + _last_visible_before_tool_buffer = None + if _check_tool_markers: + from modules.tool_parsing import streaming_tool_buffer_check, detect_tool_call_format + _tool_names = [t['function']['name'] for t in state['tools'] if 'function' in t and 'name' in t['function']] + _template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '') + _, _streaming_markers, _check_bare_names = detect_tool_call_format(_template_str) + # Generate reply = None for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True, for_ui=for_ui)): + # Prepend thinking tag if the template appended it to the prompt + if thinking_prefix: + reply = thinking_prefix + reply + # Extract the reply if state['mode'] in ['chat', 'chat-instruct']: if not _continue: @@ -982,7 +1087,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')] # Keep version metadata in sync during streaming (for regeneration) - if regenerate: + if regenerate and not state.get('_tool_turn'): row_idx = len(output['internal']) - 1 key = f"assistant_{row_idx}" current_idx = output['metadata'][key]['current_version_index'] @@ -992,25 +1097,35 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess }) if is_stream: + if _check_tool_markers: + if streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names): + continue + _last_visible_before_tool_buffer 
= output['visible'][-1][1] + yield output if _continue: - # Reprocess the entire internal text for extensions (like translation) - full_internal = output['internal'][-1][1] - if state['mode'] in ['chat', 'chat-instruct']: - full_visible = re.sub("(||{{user}})", state['name1'], full_internal) - else: - full_visible = full_internal + # Reprocess the entire internal text for extensions (like translation). + # Skip entirely when the visible text contains markers, + # since those only exist in visible (internal is cleared after each tool + # execution) and rebuilding from internal would destroy them. Output + # extensions also can't handle the raw markup safely. + if '' not in output['visible'][-1][1]: + full_internal = output['internal'][-1][1] + if state['mode'] in ['chat', 'chat-instruct']: + full_visible = re.sub("(||{{user}})", state['name1'], full_internal) + else: + full_visible = full_internal - full_visible = html.escape(full_visible) - if not state.get('_skip_output_extensions'): - output['visible'][-1][1] = apply_extensions('output', full_visible, state, is_chat=True) + full_visible = html.escape(full_visible) + if not state.get('_skip_output_extensions'): + output['visible'][-1][1] = apply_extensions('output', full_visible, state, is_chat=True) else: if not state.get('_skip_output_extensions'): output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True) # Final sync for version metadata (in case streaming was disabled) - if regenerate: + if regenerate and not state.get('_tool_turn'): row_idx = len(output['internal']) - 1 key = f"assistant_{row_idx}" current_idx = output['metadata'][key]['current_version_index'] @@ -1019,6 +1134,13 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess 'visible_content': output['visible'][row_idx][1] }) + # When tool markers were detected during streaming, restore the last + # visible text from before buffering started so raw markup doesn't flash + # in the 
UI. The internal text is left intact so the caller can still + # parse tool calls from it. + if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names): + output['visible'][-1][1] = _last_visible_before_tool_buffer or '' + yield output @@ -1064,7 +1186,11 @@ def character_is_loaded(state, raise_exception=False): def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): ''' - Same as above but returns HTML for the UI + Same as above but returns HTML for the UI. + When tools are selected, wraps generation in a loop that detects + tool calls, executes them, and re-generates until the model stops. + All tool output is consolidated into a single visible chat bubble + using metadata['assistant_N']['tool_sequence']. ''' if not character_is_loaded(state): @@ -1079,19 +1205,257 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): send_dummy_message(text, state) send_dummy_reply(state['start_with'], state) - history = state['history'] + # On regenerate, clear old tool_sequence metadata so it gets rebuilt. + # Save it first so it can be stored per-version below. + # This must happen after the start_with logic above, which may remove + # and re-add messages, changing which row we operate on. 
+ _old_tool_sequence = None + if regenerate: + history = state['history'] + meta = history.get('metadata', {}) + row_idx = len(history['internal']) - 1 + if row_idx >= 0: + _old_tool_sequence = meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None) + + # Load tools if any are selected + selected = state.get('selected_tools', []) + parse_tool_call = None + _tool_parsers = None + if selected: + from modules.tool_use import load_tools, execute_tool + from modules.tool_parsing import parse_tool_call, get_tool_call_id, detect_tool_call_format + + if selected: + tool_defs, tool_executors = load_tools(selected) + state['tools'] = tool_defs + tool_func_names = [t['function']['name'] for t in tool_defs] + _template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '') + _tool_parsers, _, _ = detect_tool_call_format(_template_str) + else: + tool_func_names = None + + visible_prefix = [] # Accumulated tool call summaries + results last_save_time = time.monotonic() save_interval = 8 - for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)): - yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history - if i == 0: - time.sleep(0.125) # We need this to make sure the first update goes through + _tool_turn = 0 + while True: + history = state['history'] - current_time = time.monotonic() - # Save on first iteration or if save_interval seconds have passed - if i == 0 or (current_time - last_save_time) >= save_interval: + # Turn 0: use original flags; turns 2+: regenerate into the same entry. + # _tool_turn tells chatbot_wrapper to skip version creation/sync so + # that intermediate tool-loop regenerations don't pollute swipe history. 
+ if _tool_turn > 0: + state['_tool_turn'] = True + state['_skip_output_extensions'] = True + + regen = regenerate if _tool_turn == 0 else True + cont = _continue if _tool_turn == 0 else False + cur_text = text if _tool_turn == 0 else '' + + for i, history in enumerate(generate_chat_reply(cur_text, state, regen, cont, loading_message=True, for_ui=True)): + # Prepend accumulated tool output to visible reply for display. + # Save and restore the original to prevent the markers from leaking + # back into chatbot_wrapper's shared output object, which would cause + # duplication on the next yield. + _original_visible = history['visible'][-1][1] if visible_prefix else None + if visible_prefix: + history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_original_visible]) + + yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history + + if visible_prefix: + history['visible'][-1][1] = _original_visible + + if i == 0: + # Save old tool_sequence into version 0 (created by chatbot_wrapper + # on the first yield). Only needed on the first regeneration when + # versions didn't previously exist. 
+ if _old_tool_sequence is not None and _tool_turn == 0: + _ri = len(history['internal']) - 1 + _versions = history.get('metadata', {}).get(f'assistant_{_ri}', {}).get('versions', []) + if _versions and 'tool_sequence' not in _versions[0]: + _versions[0]['tool_sequence'] = _old_tool_sequence + _old_tool_sequence = None + + time.sleep(0.125) + + current_time = time.monotonic() + if i == 0 or (current_time - last_save_time) >= save_interval: + save_history(history, state['unique_id'], state['character_menu'], state['mode']) + last_save_time = current_time + + # Early stop on tool call detection + if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names, parsers=_tool_parsers): + break + + # Save the model's visible output before re-applying visible_prefix, + # so we can extract thinking content from just this turn's output. + _model_visible = history['visible'][-1][1] + + # Recover visible_prefix from existing visible text (e.g. on Continue + # after a previous session had tool calls). Extract all + # blocks and any text between them (thinking blocks, intermediate text). + if tool_func_names and not visible_prefix and _model_visible: + tc_matches = list(re.finditer(r'.*?', _model_visible, re.DOTALL)) + if tc_matches: + prefix_end = tc_matches[-1].end() + prefix = _model_visible[:prefix_end].strip() + if prefix: + visible_prefix = [prefix] + _model_visible = _model_visible[prefix_end:].strip() + + # Re-apply visible prefix to the final state after streaming completes. + # This is safe because we're no longer sharing the object with chatbot_wrapper. 
+ if visible_prefix: + history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_model_visible]) + + if tool_func_names: save_history(history, state['unique_id'], state['character_menu'], state['mode']) - last_save_time = current_time + + # Check for tool calls + if not tool_func_names or shared.stop_everything: + break + + answer = history['internal'][-1][1] + parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True, parsers=_tool_parsers) if answer else (None, '') + + if not parsed_calls: + break # No tool calls — done + + # --- Process tool calls --- + row_idx = len(history['internal']) - 1 + meta = history.get('metadata', {}) + seq = meta.setdefault(f'assistant_{row_idx}', {}).setdefault('tool_sequence', []) + + def _render(): + return chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) + + # Serialize tool calls and build display headers in one pass + serialized = [] + tc_headers = [] + for tc in parsed_calls: + tc['id'] = get_tool_call_id() + fn_name = tc['function']['name'] + fn_args = tc['function'].get('arguments', {}) + + serialized.append({ + 'id': tc['id'], + 'type': 'function', + 'function': { + 'name': fn_name, + 'arguments': json.dumps(fn_args) if isinstance(fn_args, dict) else fn_args + } + }) + + if isinstance(fn_args, dict) and fn_args: + args_summary = ', '.join(f'{k}={json.dumps(v, ensure_ascii=False)}' for k, v in fn_args.items()) + elif isinstance(fn_args, dict): + args_summary = '' + else: + args_summary = str(fn_args) + + tc_headers.append(f'{fn_name}({args_summary})') + + seq_entry = {'tool_calls': serialized} + if content_prefix.strip(): + # Strip GPT-OSS channel tokens so they don't get double-wrapped + # by the template (which adds its own channel markup). 
+ clean = content_prefix.strip() + if '<|channel|>' in clean and '<|message|>' in clean: + inner = clean.split('<|message|>', 1)[1] + if '<|end|>' in inner: + inner = inner.split('<|end|>', 1)[0] + clean = inner.strip() + if clean: + seq_entry['content'] = clean + seq.append(seq_entry) + + # Clear internal (raw tool markup) + history['internal'][-1][1] = '' + + # Preserve thinking block and intermediate text from this turn. + # content_prefix is the raw text before tool call syntax (returned + # by parse_tool_call); HTML-escape it and extract thinking to get + # the content the user should see. + content_text = html.escape(content_prefix) + thinking_content, intermediate = extract_thinking_block(content_text) + if thinking_content: + visible_prefix.append(f'<think>\n{thinking_content}\n</think>') + if intermediate and intermediate.strip(): + visible_prefix.append(intermediate.strip()) + + # Show placeholder accordions with "..." before execution starts + # (tool calls may be slow, e.g. web search). 
+ pending_placeholders = [f'{h}\n...\n' for h in tc_headers] + history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders) + yield _render(), history + + # Execute tools, store results, and replace placeholders with real results + for i, tc in enumerate(parsed_calls): + # Check for stop request before each tool execution + if shared.stop_everything: + for j in range(i, len(parsed_calls)): + seq.append({'role': 'tool', 'content': 'Tool execution was cancelled by the user.', 'tool_call_id': parsed_calls[j]['id']}) + pending_placeholders[j] = f'{tc_headers[j]}\nCancelled\n' + + history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders) + yield _render(), history + break + + fn_name = tc['function']['name'] + fn_args = tc['function'].get('arguments', {}) + result = execute_tool(fn_name, fn_args, tool_executors) + + seq.append({'role': 'tool', 'content': result, 'tool_call_id': tc['id']}) + try: + pretty_result = json.dumps(json.loads(result), indent=2, ensure_ascii=False) + except (json.JSONDecodeError, TypeError): + pretty_result = result + + # Replace the placeholder with the real result + pending_placeholders[i] = f'{tc_headers[i]}\n{pretty_result}\n' + history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders) + yield _render(), history + + # Move completed tool calls into visible_prefix for next turns + visible_prefix.extend(pending_placeholders) + history['visible'][-1][1] = '\n\n'.join(visible_prefix) + save_history(history, state['unique_id'], state['character_menu'], state['mode']) + + state['history'] = history + _tool_turn += 1 + + state.pop('_tool_turn', None) + + # If output extensions were deferred during tool turns, apply them now + # to the final model response only (not to tool call markers). 
+ if state.pop('_skip_output_extensions', None): + _model_visible = apply_extensions('output', _model_visible, state, is_chat=True) + if visible_prefix: + history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_model_visible]) + else: + history['visible'][-1][1] = _model_visible + + yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history + + state['history'] = history + + # Sync version metadata so swipes show the full visible (with tool prefix) + if visible_prefix and history.get('metadata'): + row_idx = len(history['internal']) - 1 + key = f"assistant_{row_idx}" + meta_entry = history['metadata'].get(key, {}) + if 'versions' in meta_entry and 'current_version_index' in meta_entry: + current_idx = meta_entry['current_version_index'] + if current_idx < len(meta_entry['versions']): + version_update = { + 'content': history['internal'][row_idx][1], + 'visible_content': history['visible'][row_idx][1] + } + ts = meta_entry.get('tool_sequence') + if ts is not None: + version_update['tool_sequence'] = ts + meta_entry['versions'][current_idx].update(version_update) save_history(history, state['unique_id'], state['character_menu'], state['mode']) @@ -1164,7 +1528,7 @@ def redraw_html(history, name1, name2, mode, style, character, reset_cache=False return chat_html_wrapper(history, name1, name2, mode, style, character, reset_cache=reset_cache) -def start_new_chat(state): +def start_new_chat(state, unique_id=None): mode = state['mode'] # Initialize with empty metadata dictionary history = {'internal': [], 'visible': [], 'metadata': {}} @@ -1178,7 +1542,9 @@ def start_new_chat(state): # Add timestamp for assistant's greeting update_message_metadata(history['metadata'], "assistant", 0, timestamp=get_current_timestamp()) - unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S') + if unique_id is None: + unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S') + save_history(history, unique_id, 
state['character_menu'], state['mode']) return history @@ -1197,12 +1563,16 @@ def save_history(history, unique_id, character, mode): if shared.args.multi_user: return + if unique_id and unique_id.startswith('incognito-'): + return + p = get_history_file_path(unique_id, character, mode) if not p.parent.is_dir(): p.parent.mkdir(parents=True) - with open(p, 'w', encoding='utf-8') as f: - f.write(json.dumps(history, indent=4, ensure_ascii=False)) + with _history_file_lock: + with open(p, 'w', encoding='utf-8') as f: + f.write(json.dumps(history, indent=4, ensure_ascii=False)) def rename_history(old_id, new_id, character, mode): @@ -1333,6 +1703,7 @@ def load_history_after_deletion(state, idx): Loads the latest history for the given character in chat or chat-instruct mode, or the latest instruct history for instruct mode. ''' + import gradio as gr if shared.args.multi_user: return start_new_chat(state) @@ -1351,6 +1722,7 @@ def load_history_after_deletion(state, idx): def update_character_menu_after_deletion(idx): + import gradio as gr characters = utils.get_available_characters() idx = min(int(idx), len(characters) - 1) idx = max(0, idx) @@ -1383,6 +1755,9 @@ def save_last_chat_state(character, mode, unique_id): if shared.args.multi_user: return + if unique_id and unique_id.startswith('incognito-'): + return + state = load_last_chat_state() key = get_chat_state_key(character, mode) state["last_chats"][key] = unique_id @@ -1565,24 +1940,6 @@ def clear_character_for_ui(state): return state, state['name2'], state['context'], state['greeting'], None -def load_instruction_template(template): - if template == 'None': - return '' - - for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']: - if filepath.exists(): - break - else: - return '' - - file_contents = open(filepath, 'r', encoding='utf-8').read() - data = yaml.safe_load(file_contents) - if 'instruction_template' in data: - 
return data['instruction_template'] - else: - return jinja_template_from_old_format(data) - - @functools.cache def load_character_memoized(character, name1, name2): return load_character(character, name1, name2) @@ -1590,10 +1947,12 @@ def load_character_memoized(character, name1, name2): @functools.cache def load_instruction_template_memoized(template): + from modules.models_settings import load_instruction_template return load_instruction_template(template) def upload_character(file, img_path, tavern=False): + import gradio as gr img = open_image_safely(img_path) decoded_file = file if isinstance(file, str) else file.decode('utf-8') try: @@ -1647,6 +2006,7 @@ def upload_tavern_character(img_path, _json): def check_tavern_character(img_path): + import gradio as gr img = open_image_safely(img_path) if img is None: @@ -1832,6 +2192,7 @@ def delete_user(name): def update_user_menu_after_deletion(idx): """Update user menu after a user is deleted""" + import gradio as gr users = get_available_users() if len(users) == 0: # Create a default user if none exist @@ -1864,93 +2225,13 @@ def handle_user_menu_change(state): def handle_save_user_click(name1): """Handle save user button click""" + import gradio as gr return [ name1, gr.update(visible=True) ] -def jinja_template_from_old_format(params, verbose=False): - MASTER_TEMPLATE = """ -{%- set ns = namespace(found=false) -%} -{%- for message in messages -%} - {%- if message['role'] == 'system' -%} - {%- set ns.found = true -%} - {%- endif -%} -{%- endfor -%} -{%- if not ns.found -%} - {{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}} -{%- endif %} -{%- for message in messages %} - {%- if message['role'] == 'system' -%} - {{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}} - {%- else -%} - {%- if message['role'] == 'user' -%} - {{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}} - {%- else -%} - {{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}} - {%- endif -%} - 
{%- endif -%} -{%- endfor -%} -{%- if add_generation_prompt -%} - {{-'<|PRE-ASSISTANT-GENERATE|>'-}} -{%- endif -%} -""" - - if 'context' in params and '<|system-message|>' in params['context']: - pre_system = params['context'].split('<|system-message|>')[0] - post_system = params['context'].split('<|system-message|>')[1] - else: - pre_system = '' - post_system = '' - - pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user']) - post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0] - - pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1] - pre_assistant = pre_assistant.replace('<|bot|>', params['bot']) - post_assistant = params['turn_template'].split('<|bot-message|>')[1] - - def preprocess(string): - return string.replace('\n', '\\n').replace('\'', '\\\'') - - pre_system = preprocess(pre_system) - post_system = preprocess(post_system) - pre_user = preprocess(pre_user) - post_user = preprocess(post_user) - pre_assistant = preprocess(pre_assistant) - post_assistant = preprocess(post_assistant) - - if verbose: - print( - '\n', - repr(pre_system) + '\n', - repr(post_system) + '\n', - repr(pre_user) + '\n', - repr(post_user) + '\n', - repr(pre_assistant) + '\n', - repr(post_assistant) + '\n', - ) - - result = MASTER_TEMPLATE - if 'system_message' in params: - result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message'])) - else: - result = result.replace('<|SYSTEM-MESSAGE|>', '') - - result = result.replace('<|PRE-SYSTEM|>', pre_system) - result = result.replace('<|POST-SYSTEM|>', post_system) - result = result.replace('<|PRE-USER|>', pre_user) - result = result.replace('<|POST-USER|>', post_user) - result = result.replace('<|PRE-ASSISTANT|>', pre_assistant) - result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' ')) - result = result.replace('<|POST-ASSISTANT|>', post_assistant) - - result = result.strip() 
- - return result - - def my_yaml_output(data): ''' pyyaml is very inconsistent with multiline strings. @@ -2002,6 +2283,7 @@ def handle_unique_id_select(state): def handle_start_new_chat_click(state): + import gradio as gr history = start_new_chat(state) histories = find_all_histories_with_first_prompts(state) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) @@ -2016,10 +2298,29 @@ def handle_start_new_chat_click(state): return [history, html, past_chats_update] +def handle_start_incognito_chat_click(state): + import gradio as gr + unique_id = 'incognito-' + datetime.now().strftime('%Y%m%d-%H-%M-%S') + history = start_new_chat(state, unique_id=unique_id) + html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) + + convert_to_markdown.cache_clear() + + histories = find_all_histories_with_first_prompts(state) + past_chats_update = gr.update(choices=histories, value=unique_id) + + return [history, html, past_chats_update] + + def handle_delete_chat_confirm_click(state): filtered_histories = find_all_histories_with_first_prompts(state) filtered_ids = [h[1] for h in filtered_histories] - index = str(filtered_ids.index(state['unique_id'])) + + if state['unique_id'] not in filtered_ids: + # Incognito or unknown chat — just load the most recent saved chat + index = '0' + else: + index = str(filtered_ids.index(state['unique_id'])) delete_history(state['unique_id'], state['character_menu'], state['mode']) history, unique_id = load_history_after_deletion(state, index) @@ -2027,16 +2328,11 @@ def handle_delete_chat_confirm_click(state): convert_to_markdown.cache_clear() - return [ - history, - html, - unique_id, - gr.update(visible=False), - gr.update(visible=True), - ] + return [history, html, unique_id] def handle_branch_chat_click(state): + import gradio as gr branch_from_index = state['branch_index'] if branch_from_index == -1: history 
= state['history'] @@ -2048,7 +2344,8 @@ def handle_branch_chat_click(state): if 'metadata' in history: history['metadata'] = {k: v for k, v in history['metadata'].items() if int(k.split('_')[-1]) <= branch_from_index} - new_unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S') + prefix = 'incognito-' if state['unique_id'] and state['unique_id'].startswith('incognito-') else '' + new_unique_id = prefix + datetime.now().strftime('%Y%m%d-%H-%M-%S') save_history(history, new_unique_id, state['character_menu'], state['mode']) histories = find_all_histories_with_first_prompts(state) @@ -2086,14 +2383,19 @@ def handle_edit_message_click(state): original_visible = history['visible'][message_index][role_idx] original_timestamp = history['metadata'][key].get('timestamp', get_current_timestamp()) - history['metadata'][key]["versions"] = [{ + version_entry = { "content": original_content, "visible_content": original_visible, "timestamp": original_timestamp - }] + } + ts = history['metadata'][key].get('tool_sequence') + if ts is not None: + version_entry['tool_sequence'] = ts + history['metadata'][key]["versions"] = [version_entry] history['internal'][message_index][role_idx] = apply_extensions('input', new_text, state, is_chat=True) history['visible'][message_index][role_idx] = html.escape(new_text) + history['metadata'][key].pop('tool_sequence', None) add_message_version(history, role, message_index, is_current=True) @@ -2138,6 +2440,14 @@ def handle_navigate_version_click(state): history['internal'][message_index][msg_content_idx] = version_to_load['content'] history['visible'][message_index][msg_content_idx] = version_to_load['visible_content'] metadata['current_version_index'] = new_idx + + # Restore per-version tool_sequence so follow-up prompts see consistent context + version_ts = version_to_load.get('tool_sequence') + if version_ts is not None: + metadata['tool_sequence'] = version_ts + else: + metadata.pop('tool_sequence', None) + 
update_message_metadata(history['metadata'], role, message_index, timestamp=version_to_load['timestamp']) # Redraw and save @@ -2148,6 +2458,7 @@ def handle_navigate_version_click(state): def handle_rename_chat_click(): + import gradio as gr return [ gr.update(value="My New Chat"), gr.update(visible=True), @@ -2155,6 +2466,14 @@ def handle_rename_chat_click(): def handle_rename_chat_confirm(rename_to, state): + import gradio as gr + + if state['unique_id'] and state['unique_id'].startswith('incognito-'): + return [ + gr.update(), + gr.update(visible=False), + ] + rename_history(state['unique_id'], rename_to, state['character_menu'], state['mode']) histories = find_all_histories_with_first_prompts(state) @@ -2165,11 +2484,13 @@ def handle_rename_chat_confirm(rename_to, state): def handle_search_chat_change(state): + import gradio as gr histories = find_all_histories_with_first_prompts(state) return gr.update(choices=histories) def handle_upload_chat_history(load_chat_history, state): + import gradio as gr history = start_new_chat(state) history = load_history_json(load_chat_history, history) save_history(history, state['unique_id'], state['character_menu'], state['mode']) @@ -2192,6 +2513,7 @@ def handle_upload_chat_history(load_chat_history, state): def handle_character_menu_change(state): + import gradio as gr name1, name2, picture, greeting, context = load_character(state['character_menu'], state['name1'], state['name2']) state['name1'] = name1 @@ -2244,6 +2566,7 @@ def handle_character_picture_change(picture_path): def handle_mode_change(state): + import gradio as gr history, loaded_unique_id = load_latest_history(state) histories = find_all_histories_with_first_prompts(state) @@ -2270,6 +2593,7 @@ def handle_mode_change(state): def handle_save_character_click(name2): + import gradio as gr return [ name2, gr.update(visible=True) @@ -2277,6 +2601,7 @@ def handle_save_character_click(name2): def handle_load_template_click(instruction_template): + from 
modules.models_settings import load_instruction_template output = load_instruction_template(instruction_template) return [ output, @@ -2285,6 +2610,7 @@ def handle_load_template_click(instruction_template): def handle_save_template_click(instruction_template_str): + import gradio as gr contents = generate_instruction_template_yaml(instruction_template_str) return [ "My Template.yaml", @@ -2295,6 +2621,7 @@ def handle_save_template_click(instruction_template_str): def handle_delete_template_click(template): + import gradio as gr return [ f"{template}.yaml", str(shared.user_data_dir / 'instruction-templates') + '/', @@ -2310,6 +2637,7 @@ def handle_your_picture_change(picture, state): def handle_send_instruction_click(state): + import gradio as gr state['mode'] = 'instruct' state['history'] = {'internal': [], 'visible': [], 'metadata': {}} @@ -2322,6 +2650,7 @@ def handle_send_instruction_click(state): def handle_send_chat_click(state): + import gradio as gr output = generate_chat_prompt("", state, _continue=True) if state["show_two_notebook_columns"]: diff --git a/modules/exllamav3.py b/modules/exllamav3.py index b4b76e21..1c682e49 100644 --- a/modules/exllamav3.py +++ b/modules/exllamav3.py @@ -1,3 +1,4 @@ +import math import queue import threading import traceback @@ -9,6 +10,7 @@ import torch from exllamav3 import Cache, Config, Generator, Model, Tokenizer from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant from exllamav3.generator import Job +from exllamav3.generator.filter import Filter from exllamav3.generator.sampler import ( CustomSampler, SS_AdaptiveP, @@ -36,6 +38,29 @@ except Exception: traceback.print_exc() +class LogitBiasFilter(Filter): + """Filter subclass that applies a static additive logit bias mask.""" + + def __init__(self, tokenizer, logit_bias_dict): + super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False) + self.logit_bias_dict = logit_bias_dict + self._mask = None + + def reset(self): 
pass + def accept_token(self, token): pass + def is_completed(self): return False + def use_background_worker(self): return False + + def get_next_logit_mask(self): + if self._mask is None: + self._mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype) + for token_id_str, bias in self.logit_bias_dict.items(): + token_id = int(token_id_str) + if 0 <= token_id < self.vocab_size: + self._mask[0, token_id] = bias + return self._mask + + class ConcurrentGenerator: def __init__(self, generator): self.generator = generator @@ -53,7 +78,16 @@ class ConcurrentGenerator: if not self.job_queues: self.has_jobs.clear() continue - results = self.generator.iterate() + try: + results = self.generator.iterate() + except Exception: + logger.error("Exception in ConcurrentGenerator iterate loop:\n" + traceback.format_exc()) + for q in self.job_queues.values(): + q.put(None) + self.job_queues.clear() + self.generator.clear_queue() + self.has_jobs.clear() + continue for result in results: job = result["job"] q = self.job_queues.get(job) @@ -89,6 +123,10 @@ class Exllamav3Model: def __init__(self): pass + @property + def device(self) -> torch.device: + return torch.device(0) + @classmethod def from_pretrained(cls, path_to_model): path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model) @@ -149,8 +187,21 @@ class Exllamav3Model: load_params['tensor_p'] = True load_params['tp_backend'] = shared.args.tp_backend - model.load(**load_params) - tokenizer = Tokenizer.from_config(config) + # Load vision and draft before the main model so autosplit + # accounts for their VRAM usage. + + # Load vision model component (ExLlamaV3 native) + vision_model = None + if "vision_config" in config.config_dict: + logger.info("Vision component detected in model config. 
Attempting to load...") + try: + vision_model = Model.from_config(config, component="vision") + vision_model.load(progressbar=True) + logger.info("Vision model loaded successfully.") + except Exception as e: + logger.warning(f"Vision model loading failed (multimodal disabled): {e}") + else: + logger.info("No vision component in model config. Skipping multimodal setup.") # Initialize draft model for speculative decoding draft_model = None @@ -166,23 +217,8 @@ class Exllamav3Model: logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.") else: draft_config = Config.from_directory(str(draft_path)) - - # Set context size for draft model with 256-multiple validation - if shared.args.ctx_size_draft > 0: - draft_max_tokens = shared.args.ctx_size_draft - else: - draft_max_tokens = shared.args.ctx_size - - # Validate draft model context size is a multiple of 256 - if draft_max_tokens % 256 != 0: - adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256 - logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}") - draft_max_tokens = adjusted_draft_tokens - - draft_config.max_seq_len = draft_max_tokens - draft_model = Model.from_config(draft_config) - draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs) + draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs) draft_load_params = {'progressbar': True} if split: @@ -191,18 +227,9 @@ class Exllamav3Model: draft_model.load(**draft_load_params) logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}") - # Load vision model component (ExLlamaV3 native) - vision_model = None - if "vision_config" in config.config_dict: - logger.info("Vision component detected in model config. 
Attempting to load...") - try: - vision_model = Model.from_config(config, component="vision") - vision_model.load(progressbar=True) - logger.info("Vision model loaded successfully.") - except Exception as e: - logger.warning(f"Vision model loading failed (multimodal disabled): {e}") - else: - logger.info("No vision component in model config. Skipping multimodal setup.") + # Load main model last + model.load(**load_params) + tokenizer = Tokenizer.from_config(config) generator = Generator( model=model, @@ -385,11 +412,22 @@ class Exllamav3Model: else: max_new_tokens = state['max_new_tokens'] - # Get stop conditions + # Use full EOS token list from config (may contain multiple IDs) stop_conditions = [] if not state['ban_eos_token']: - if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: - stop_conditions.append(self.tokenizer.eos_token_id) + for eos_id in self.config.eos_token_id_list: + if eos_id is not None: + stop_conditions.append(eos_id) + + # Build filters for logit_bias (OpenAI API) + filters = [] + logit_bias = state.get('logit_bias') + if logit_bias: + filters.append(LogitBiasFilter(self.tokenizer, logit_bias)) + + # Logprobs support (OpenAI API) + logprobs = state.get('logprobs', 0) or 0 + return_top_tokens = logprobs if logprobs > 0 else 0 seed = state.get('seed', -1) job = Job( @@ -400,11 +438,15 @@ class Exllamav3Model: sampler=sampler, seed=seed if seed >= 0 else None, stop_conditions=stop_conditions if stop_conditions else None, + filters=filters if filters else None, + return_top_tokens=return_top_tokens, + return_probs=return_top_tokens > 0, ) # Stream generation response_text = "" stop_event = state.get('stop_event') + self.last_completion_probabilities = [] result_queue = self.parallel_generator.submit(job) try: @@ -416,14 +458,41 @@ class Exllamav3Model: except queue.Empty: continue if result is None or result.get("eos"): + # Capture logprobs from the final eos result too + if result is not None and 
return_top_tokens > 0: + self._capture_logprobs(result) break chunk = result.get("text", "") + + # Capture logprobs from streaming results + if return_top_tokens > 0: + self._capture_logprobs(result) + if chunk: response_text += chunk yield response_text finally: self.parallel_generator.cancel(job) + def _capture_logprobs(self, result): + """Convert ExLlamav3 top-k token data to the shared logprobs format.""" + top_k_tokens = result.get("top_k_tokens") + top_k_probs = result.get("top_k_probs") + if top_k_tokens is None or top_k_probs is None: + return + + id_to_piece = self.tokenizer.get_id_to_piece_list(True) + # top_k_tokens shape: (batch, seq_len, k), top_k_probs same + for seq_idx in range(top_k_tokens.shape[1]): + entry = {"top_logprobs": []} + for k_idx in range(top_k_tokens.shape[2]): + token_id = top_k_tokens[0, seq_idx, k_idx].item() + prob = top_k_probs[0, seq_idx, k_idx].item() + token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>" + logprob = math.log(prob) if prob > 0 else float("-inf") + entry["top_logprobs"].append({"token": token_str, "logprob": logprob}) + self.last_completion_probabilities.append(entry) + def generate(self, prompt, state): output = "" for chunk in self.generate_with_streaming(prompt, state): diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py index b4b6ad20..d3c1cb90 100644 --- a/modules/exllamav3_hf.py +++ b/modules/exllamav3_hf.py @@ -201,19 +201,23 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): } ).to(input_ids.device).float() else: - # When processing with labels, handle as a complete sequence - # Process in chunks if the number of tokens is large + # Labels path: use cache for cross-chunk attention. 
tokens_to_process = seq_tensor all_logits = None + current_len = 0 for i in range(0, tokens_to_process.shape[0], max_chunk_size): chunk = tokens_to_process[i:i + max_chunk_size] chunk_logits = self.ex_model.forward( input_ids=chunk.view(1, -1), params={ - "attn_mode": "flash_attn_nc", + "attn_mode": "flash_attn", + "cache": ex_cache, + "past_len": current_len, + "batch_shape": (1, self.max_tokens), } ).float() + current_len += chunk.shape[0] if all_logits is None: all_logits = chunk_logits diff --git a/modules/extensions.py b/modules/extensions.py index dd327882..e58a9a4c 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -6,8 +6,6 @@ from functools import partial from inspect import signature from pathlib import Path -import gradio as gr - import modules.shared as shared from modules.logging_colors import logger @@ -214,6 +212,7 @@ def _apply_custom_js(): def create_extensions_block(): + import gradio as gr to_display = [] for extension, name in iterator(): if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)): @@ -228,6 +227,7 @@ def create_extensions_block(): def create_extensions_tabs(): + import gradio as gr for extension, name in iterator(): if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)): display_name = getattr(extension, 'params', {}).get('display_name', name) diff --git a/modules/html_generator.py b/modules/html_generator.py index 472a9ea0..8f3f261f 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -10,6 +10,7 @@ import markdown from PIL import Image, ImageOps from modules import shared +from modules.reasoning import extract_reasoning from modules.sane_markdown_lists import SaneListExtension from modules.utils import get_available_chat_styles @@ -108,69 +109,41 @@ def replace_blockquote(m): return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '') -# Thinking 
block format definitions: (start_tag, end_tag, content_start_tag) -# Use None for start_tag to match from beginning (end-only formats should be listed last) -THINKING_FORMATS = [ - ('', '', None), - ('<|channel|>analysis<|message|>', '<|end|>', '<|start|>assistant<|channel|>final<|message|>'), - ('', '', None), - ('<|think|>', '<|end|>', '<|content|>'), # Solar Open - ('Thinking Process:', '', None), # Qwen3.5 verbose thinking outside tags - (None, '', None), # End-only variant (e.g., Qwen3-next) -] - - def extract_thinking_block(string): - """Extract thinking blocks from the beginning of a string.""" - if not string: - return None, string - - for start_tag, end_tag, content_tag in THINKING_FORMATS: - end_esc = html.escape(end_tag) - content_esc = html.escape(content_tag) if content_tag else None - - if start_tag is None: - # End-only format: require end tag, start from beginning - end_pos = string.find(end_esc) - if end_pos == -1: - continue - thought_start = 0 - else: - # Normal format: require start tag - start_esc = html.escape(start_tag) - start_pos = string.find(start_esc) - if start_pos == -1: - continue - thought_start = start_pos + len(start_esc) - end_pos = string.find(end_esc, thought_start) - - if end_pos == -1: - # End tag missing - check if content tag can serve as fallback - if content_esc: - content_pos = string.find(content_esc, thought_start) - if content_pos != -1: - thought_end = content_pos - content_start = content_pos + len(content_esc) - else: - thought_end = len(string) - content_start = len(string) - else: - thought_end = len(string) - content_start = len(string) - else: - thought_end = end_pos - if content_esc: - content_pos = string.find(content_esc, end_pos) - content_start = content_pos + len(content_esc) if content_pos != -1 else end_pos + len(end_esc) - else: - content_start = end_pos + len(end_esc) - - return string[thought_start:thought_end], string[content_start:] - - return None, string + """Extract thinking blocks from the 
beginning of an HTML-escaped string.""" + return extract_reasoning(string, html_escaped=True) -def build_thinking_block(thinking_content, message_id, has_remaining_content): + +def build_tool_call_block(header, body, message_id, index): + """Build HTML for a tool call accordion block.""" + block_id = f"tool-call-{message_id}-{index}" + + if body == '...': + # Pending placeholder — no expandable body, just title with ellipsis + return f''' +
+ + {tool_svg_small} + {html.escape(header)} ... + +
+ ''' + + # Build a plain
 directly to avoid highlight.js auto-detection
+    escaped_body = html.escape(body)
+    return f'''
+    
+ + {tool_svg_small} + {html.escape(header)} + +
{escaped_body}
+
+ ''' + + +def build_thinking_block(thinking_content, message_id, has_remaining_content, thinking_index=0): """Build HTML for a thinking block.""" if thinking_content is None: return None @@ -179,7 +152,7 @@ def build_thinking_block(thinking_content, message_id, has_remaining_content): thinking_html = process_markdown_content(thinking_content) # Generate unique ID for the thinking block - block_id = f"thinking-{message_id}-0" + block_id = f"thinking-{message_id}-{thinking_index}" # Check if thinking is complete or still in progress is_streaming = not has_remaining_content @@ -344,6 +317,9 @@ def process_markdown_content(string): # Unescape backslashes html_output = html_output.replace('\\\\', '\\') + # Wrap tables in a scrollable div + html_output = html_output.replace('', '
').replace('
', '') + return html_output @@ -360,24 +336,66 @@ def convert_to_markdown(string, message_id=None): if message_id is None: message_id = "unknown" - # Extract different components from the string - thinking_content, remaining_content = extract_thinking_block(string) + # Find tool call blocks by position, then process the text segments + # between them using extract_thinking_block (which supports all + # THINKING_FORMATS, including end-only variants like Qwen's). + tool_call_pattern = re.compile(r'(.*?)\n(.*?)\n', re.DOTALL) + tool_calls = list(tool_call_pattern.finditer(string)) - # Build individual HTML blocks - blocks = [] + if not tool_calls: + # No tool calls — use original single-pass extraction + thinking_content, remaining_content = extract_thinking_block(string) + blocks = [] + thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content)) + if thinking_html: + blocks.append(thinking_html) - # Add thinking block if present - thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content)) - if thinking_html: - blocks.append(thinking_html) + main_html = build_main_content_block(remaining_content) + if main_html: + blocks.append(main_html) - # Add main content block - main_html = build_main_content_block(remaining_content) - if main_html: - blocks.append(main_html) + return ''.join(blocks) - # Assemble all blocks into final HTML - return ''.join(blocks) + # Split string into text segments around tool_call blocks and + # run extract_thinking_block on each segment for full format support. 
+ html_parts = [] + last_end = 0 + tool_idx = 0 + think_idx = 0 + + def process_text_segment(text, is_last_segment): + """Process a text segment between tool_call blocks for thinking content.""" + nonlocal think_idx + if not text.strip(): + return + + while text.strip(): + thinking_content, remaining = extract_thinking_block(text) + if thinking_content is None: + break + has_remaining = bool(remaining.strip()) or not is_last_segment + html_parts.append(build_thinking_block(thinking_content, message_id, has_remaining, think_idx)) + think_idx += 1 + text = remaining + + if text.strip(): + html_parts.append(process_markdown_content(text)) + + for tc in tool_calls: + # Process text before this tool_call + process_text_segment(string[last_end:tc.start()], is_last_segment=False) + + # Add tool call accordion + header = tc.group(1).strip() + body = tc.group(2).strip() + html_parts.append(build_tool_call_block(header, body, message_id, tool_idx)) + tool_idx += 1 + last_end = tc.end() + + # Process text after the last tool_call + process_text_segment(string[last_end:], is_last_segment=True) + + return ''.join(html_parts) def convert_to_markdown_wrapped(string, message_id=None, use_cache=True): @@ -435,6 +453,7 @@ branch_svg = '''''' info_svg = '''''' info_svg_small = '''''' +tool_svg_small = '''''' attachment_svg = '''''' copy_button = f'' diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index 6f7cbd20..c3a8d105 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -36,6 +36,7 @@ class LlamaServer: self.process = None self.session = requests.Session() self.vocabulary_size = None + self.n_ctx = None self.bos_token = "" self.last_prompt_token_count = 0 @@ -133,9 +134,20 @@ class LlamaServer: payload["samplers"] = filtered_samplers + logit_bias = [] if state['custom_token_bans']: - to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')] - payload["logit_bias"] = to_ban + 
logit_bias.extend([[int(token_id.strip()), False] for token_id in state['custom_token_bans'].split(',') if token_id.strip()]) + + if state.get('logit_bias'): + for token_id_str, bias in state['logit_bias'].items(): + logit_bias.append([int(token_id_str), bias]) + + if logit_bias: + payload["logit_bias"] = logit_bias + + n_probs = state.get('logprobs', 0) + if n_probs and n_probs > 0: + payload["n_probs"] = n_probs return payload @@ -215,6 +227,7 @@ class LlamaServer: response.raise_for_status() # Raise an exception for HTTP errors full_text = "" + self.last_completion_probabilities = [] # Process the streaming response stop_event = state.get('stop_event') @@ -240,6 +253,10 @@ class LlamaServer: full_text += data['content'] yield full_text + # Capture logprobs if present + if 'completion_probabilities' in data: + self.last_completion_probabilities.extend(data['completion_probabilities']) + # Check if generation is complete if data.get('stop', False): break @@ -304,12 +321,17 @@ class LlamaServer: self.vocabulary_size = model_info["meta"]["n_vocab"] def _get_bos_token(self): - """Get and store the model's BOS token.""" + """Get and store the model's BOS token and context size.""" url = f"http://127.0.0.1:{self.port}/props" response = self.session.get(url).json() if "bos_token" in response: self.bos_token = response["bos_token"] + # Get actual n_ctx from the server (important when --fit auto-selects it) + n_ctx = response.get("default_generation_settings", {}).get("n_ctx") + if n_ctx: + self.n_ctx = n_ctx + def _is_port_available(self, port): """Check if a port is available for use.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -349,11 +371,14 @@ class LlamaServer: if shared.args.ctx_size > 0: cmd += ["--ctx-size", str(shared.args.ctx_size)] + elif shared.args.gpu_layers >= 0: + cmd += ["--ctx-size", "8192"] if shared.args.gpu_layers >= 0: cmd += ["--gpu-layers", str(shared.args.gpu_layers), "--fit", "off"] else: cmd += ["--fit", "on"] + cmd += 
["--fit-ctx", "8192"] if shared.args.fit_target: cmd += ["--fit-target", shared.args.fit_target] @@ -379,10 +404,6 @@ class LlamaServer: if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types: cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type] cache_type = shared.args.cache_type - if shared.args.compress_pos_emb != 1: - cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)] - if shared.args.rope_freq_base > 0: - cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)] if shared.args.mmproj not in [None, 'None']: path = Path(shared.args.mmproj) if not path.exists(): @@ -455,7 +476,7 @@ class LlamaServer: print() gpu_layers_str = "auto" if shared.args.gpu_layers < 0 else str(shared.args.gpu_layers) - ctx_size_str = "auto" if shared.args.ctx_size == 0 else str(shared.args.ctx_size) + ctx_size_str = "auto" if shared.args.ctx_size == 0 and shared.args.gpu_layers < 0 else str(shared.args.ctx_size or 8192) logger.info(f"Using gpu_layers={gpu_layers_str} | ctx_size={ctx_size_str} | cache_type={cache_type}") # Start the server with pipes for output self.process = subprocess.Popen( diff --git a/modules/loaders.py b/modules/loaders.py index 42a5ff1c..c90f2ebb 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -1,8 +1,6 @@ import functools from collections import OrderedDict -import gradio as gr - loaders_and_params = OrderedDict({ 'llama.cpp': [ 'gpu_layers', @@ -17,8 +15,6 @@ loaders_and_params = OrderedDict({ 'tensor_split', 'extra_flags', 'streaming_llm', - 'rope_freq_base', - 'compress_pos_emb', 'row_split', 'no_kv_offload', 'no_mmap', @@ -43,8 +39,6 @@ loaders_and_params = OrderedDict({ 'Transformers': [ 'gpu_split', 'cpu_memory', - 'alpha_value', - 'compress_pos_emb', 'compute_dtype', 'quant_type', 'load_in_8bit', @@ -71,7 +65,6 @@ loaders_and_params = OrderedDict({ 'gpu_split', 'model_draft', 'draft_max', - 'ctx_size_draft', 'speculative_decoding_accordion', 
'enable_tp', 'tp_backend', @@ -208,6 +201,7 @@ loaders_samplers = { 'ban_eos_token', 'add_bos_token', 'enable_thinking', + 'reasoning_effort', 'seed', 'skip_special_tokens', }, @@ -244,6 +238,7 @@ loaders_samplers = { 'reasoning_effort', 'seed', 'sampler_priority', + 'custom_token_bans', 'dry_sequence_breakers', 'grammar_string', 'grammar_file_row', @@ -277,6 +272,7 @@ def list_all_samplers(): def blacklist_samplers(loader, dynamic_temperature): + import gradio as gr all_samplers = list_all_samplers() output = [] @@ -302,7 +298,58 @@ def get_all_params(): return sorted(all_params) +def list_model_elements(): + return [ + 'filter_by_loader', + 'loader', + 'cpu_memory', + 'gpu_layers', + 'fit_target', + 'cpu_moe', + 'threads', + 'threads_batch', + 'batch_size', + 'ubatch_size', + 'ctx_size', + 'cache_type', + 'tensor_split', + 'extra_flags', + 'streaming_llm', + 'gpu_split', + 'compute_dtype', + 'quant_type', + 'load_in_8bit', + 'load_in_4bit', + 'attn_implementation', + 'cpu', + 'disk', + 'row_split', + 'no_kv_offload', + 'no_mmap', + 'mlock', + 'numa', + 'parallel', + 'use_double_quant', + 'bf16', + 'enable_tp', + 'tp_backend', + 'cfg_cache', + 'no_use_fast', + 'model_draft', + 'draft_max', + 'gpu_layers_draft', + 'device_draft', + 'ctx_size_draft', + 'spec_type', + 'spec_ngram_size_n', + 'spec_ngram_size_m', + 'spec_ngram_min_hits', + 'mmproj', + ] + + def make_loader_params_visible(loader): + import gradio as gr params = [] all_params = get_all_params() if loader in loaders_and_params: diff --git a/modules/models.py b/modules/models.py index 48d68b0b..1d139b89 100644 --- a/modules/models.py +++ b/modules/models.py @@ -38,6 +38,9 @@ def load_model(model_name, loader=None): sampler_hijack.hijack_samplers() shared.args.loader = loader + if loader != 'llama.cpp' and shared.args.ctx_size == 0: + shared.args.ctx_size = 8192 + output = load_func_map[loader](model_name) if type(output) is tuple: model, tokenizer = output @@ -54,6 +57,8 @@ def load_model(model_name, 
loader=None): if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp': if shared.args.ctx_size > 0: shared.settings['truncation_length'] = shared.args.ctx_size + elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx: + shared.settings['truncation_length'] = model.n_ctx shared.is_multimodal = False if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'): diff --git a/modules/models_settings.py b/modules/models_settings.py index 472871ce..f3c9a986 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -4,10 +4,9 @@ import re from math import floor from pathlib import Path -import gradio as gr import yaml -from modules import chat, loaders, metadata_gguf, shared, ui +from modules import loaders, metadata_gguf, shared from modules.logging_colors import logger from modules.utils import resolve_model_path @@ -16,9 +15,6 @@ def get_fallback_settings(): return { 'bf16': False, 'ctx_size': 8192, - 'rope_freq_base': 0, - 'compress_pos_emb': 1, - 'alpha_value': 1, 'truncation_length': shared.settings['truncation_length'], 'truncation_length_info': shared.settings['truncation_length'], 'skip_special_tokens': shared.settings['skip_special_tokens'], @@ -68,14 +64,8 @@ def get_model_metadata(model): for k in metadata: if k.endswith('.context_length'): - model_settings['ctx_size'] = min(metadata[k], 8192) + model_settings['ctx_size'] = 0 model_settings['truncation_length_info'] = metadata[k] - elif k.endswith('rope.freq_base'): - model_settings['rope_freq_base'] = metadata[k] - elif k.endswith('rope.scale_linear'): - model_settings['compress_pos_emb'] = metadata[k] - elif k.endswith('rope.scaling.factor'): - model_settings['compress_pos_emb'] = metadata[k] elif k.endswith('.block_count'): model_settings['gpu_layers'] = -1 model_settings['max_gpu_layers'] = metadata[k] + 1 @@ -120,15 +110,6 @@ def get_model_metadata(model): model_settings['ctx_size'] = min(value, 
8192) break - if 'rope_theta' in metadata: - model_settings['rope_freq_base'] = metadata['rope_theta'] - elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']: - model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta'] - - if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')): - if metadata['rope_scaling']['type'] == 'linear': - model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor'] - if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16': model_settings['bf16'] = True @@ -182,10 +163,6 @@ def get_model_metadata(model): if 'instruction_template' not in model_settings: model_settings['instruction_template'] = 'Alpaca' - # Ignore rope_freq_base if set to the default value - if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000: - model_settings.pop('rope_freq_base') - # Apply user settings from user_data/models/config-user.yaml settings = shared.user_config for pat in settings: @@ -199,7 +176,7 @@ def get_model_metadata(model): # Load instruction template if defined by name rather than by value if model_settings['instruction_template'] != 'Custom (obtained from model metadata)': - model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template']) + model_settings['instruction_template_str'] = load_instruction_template(model_settings['instruction_template']) return model_settings @@ -228,7 +205,7 @@ def update_model_parameters(state, initial=False): ''' UI: update the command-line arguments based on the interface values ''' - elements = ui.list_model_elements() # the names of the parameters + elements = loaders.list_model_elements() # the names of the parameters for i, element in enumerate(elements): if element not in state: @@ -248,6 +225,7 @@ def apply_model_settings_to_state(model, state): ''' UI: update the state variable 
with the model settings ''' + import gradio as gr model_settings = get_model_metadata(model) if 'loader' in model_settings: loader = model_settings.pop('loader') @@ -290,7 +268,7 @@ def save_model_settings(model, state): if model_regex not in user_config: user_config[model_regex] = {} - for k in ui.list_model_elements(): + for k in loaders.list_model_elements(): if k == 'loader' or k in loaders.loaders_and_params[state['loader']]: user_config[model_regex][k] = state[k] @@ -419,3 +397,103 @@ def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type): vram_usage = estimate_vram(model, gpu_layers, ctx_size, cache_type) return f"
Estimated VRAM to load the model: {vram_usage:.0f} MiB
" + + +def load_instruction_template(template): + if template == 'None': + return '' + + for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']: + if filepath.exists(): + break + else: + return '' + + with open(filepath, 'r', encoding='utf-8') as f: + file_contents = f.read() + data = yaml.safe_load(file_contents) + if 'instruction_template' in data: + return data['instruction_template'] + else: + return _jinja_template_from_old_format(data) + + +def _jinja_template_from_old_format(params, verbose=False): + MASTER_TEMPLATE = """ +{%- set ns = namespace(found=false) -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.found = true -%} + {%- endif -%} +{%- endfor -%} +{%- if not ns.found -%} + {{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}} +{%- endif %} +{%- for message in messages %} + {%- if message['role'] == 'system' -%} + {{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}} + {%- else -%} + {%- if message['role'] == 'user' -%} + {{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}} + {%- else -%} + {{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{-'<|PRE-ASSISTANT-GENERATE|>'-}} +{%- endif -%} +""" + + if 'context' in params and '<|system-message|>' in params['context']: + pre_system = params['context'].split('<|system-message|>')[0] + post_system = params['context'].split('<|system-message|>')[1] + else: + pre_system = '' + post_system = '' + + pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user']) + post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0] + + pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1] + pre_assistant = pre_assistant.replace('<|bot|>', 
params['bot']) + post_assistant = params['turn_template'].split('<|bot-message|>')[1] + + def preprocess(string): + return string.replace('\n', '\\n').replace('\'', '\\\'') + + pre_system = preprocess(pre_system) + post_system = preprocess(post_system) + pre_user = preprocess(pre_user) + post_user = preprocess(post_user) + pre_assistant = preprocess(pre_assistant) + post_assistant = preprocess(post_assistant) + + if verbose: + print( + '\n', + repr(pre_system) + '\n', + repr(post_system) + '\n', + repr(pre_user) + '\n', + repr(post_user) + '\n', + repr(pre_assistant) + '\n', + repr(post_assistant) + '\n', + ) + + result = MASTER_TEMPLATE + if 'system_message' in params: + result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message'])) + else: + result = result.replace('<|SYSTEM-MESSAGE|>', '') + + result = result.replace('<|PRE-SYSTEM|>', pre_system) + result = result.replace('<|POST-SYSTEM|>', post_system) + result = result.replace('<|PRE-USER|>', pre_user) + result = result.replace('<|POST-USER|>', post_user) + result = result.replace('<|PRE-ASSISTANT|>', pre_assistant) + result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' ')) + result = result.replace('<|POST-ASSISTANT|>', post_assistant) + + result = result.strip() + + return result diff --git a/modules/presets.py b/modules/presets.py index b53195ee..560e0b77 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -16,9 +16,10 @@ default_preset_values = { 'dynatemp_exponent': 1, 'smoothing_factor': 0, 'smoothing_curve': 1, - 'min_p': 0, 'top_p': 1, 'top_k': 0, + 'min_p': 0, + 'top_n_sigma': 0, 'typical_p': 1, 'xtc_threshold': 0.1, 'xtc_probability': 0, @@ -26,7 +27,6 @@ default_preset_values = { 'eta_cutoff': 0, 'tfs': 1, 'top_a': 0, - 'top_n_sigma': 0, 'adaptive_target': 0, 'adaptive_decay': 0.9, 'dry_multiplier': 0, diff --git a/modules/reasoning.py b/modules/reasoning.py new file mode 100644 index 00000000..bc61aab3 --- /dev/null +++ b/modules/reasoning.py @@ 
-0,0 +1,94 @@ +import html as html_module + +# Thinking block format definitions: (start_tag, end_tag, content_start_tag) +# Use None for start_tag to match from beginning (end-only formats should be listed last) +THINKING_FORMATS = [ + ('', '', None), + ('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'), + ('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'), + ('', '', None), + ('<|think|>', '<|end|>', '<|content|>'), # Solar Open + # ('Thinking Process:', '', None), # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming + (None, '', None), # End-only variant (e.g., Qwen3-next) +] + + +def extract_reasoning(text, html_escaped=False): + """Extract reasoning/thinking blocks from the beginning of a string. + + When html_escaped=True, tags are HTML-escaped before searching + (for use on already-escaped UI strings). + + Returns (reasoning_content, final_content) where reasoning_content is + None if no thinking block is found. + """ + if not text: + return None, text + + esc = html_module.escape if html_escaped else lambda s: s + + for start_tag, end_tag, content_tag in THINKING_FORMATS: + end_esc = esc(end_tag) + content_esc = esc(content_tag) if content_tag else None + + if start_tag is None: + # End-only format: require end tag, start from beginning + end_pos = text.find(end_esc) + if end_pos == -1: + continue + thought_start = 0 + else: + # Normal format: require start tag + start_esc = esc(start_tag) + start_pos = text.find(start_esc) + if start_pos == -1: + # During streaming, the start tag may be arriving partially. + # If the text is a prefix of a start tag, return empty content + # to prevent the partial tag from leaking. 
+ stripped = text.strip() + if stripped and start_esc.startswith(stripped): + return '', '' + continue + thought_start = start_pos + len(start_esc) + end_pos = text.find(end_esc, thought_start) + + if end_pos == -1: + # End tag missing - check if content tag can serve as fallback + if content_esc: + content_pos = text.find(content_esc, thought_start) + if content_pos != -1: + thought_end = content_pos + content_start = content_pos + len(content_esc) + else: + thought_end = len(text) + content_start = len(text) + else: + thought_end = len(text) + content_start = len(text) + else: + thought_end = end_pos + if content_esc: + content_pos = text.find(content_esc, end_pos) + if content_pos != -1: + content_start = content_pos + len(content_esc) + else: + # Content tag expected but not yet present (e.g. partial + # streaming) — suppress intermediate tags between end_tag + # and content_tag so they don't leak as content. + content_start = len(text) + else: + content_start = end_pos + len(end_esc) + + return text[thought_start:thought_end], text[content_start:] + + # Handle standalone GPT-OSS final channel marker without a preceding + # analysis/commentary block (the model skipped thinking entirely). + for marker in ['<|start|>assistant<|channel|>final<|message|>', '<|channel|>final<|message|>']: + marker_esc = esc(marker) + pos = text.find(marker_esc) + if pos != -1: + before = text[:pos].strip() + after = text[pos + len(marker_esc):] + return (before if before else None), after + + return None, text diff --git a/modules/shared.py b/modules/shared.py index bc7ea8ba..329114bb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -47,7 +47,7 @@ parser = argparse.ArgumentParser(description="Text Generation Web UI", conflict_ # Basic settings group = parser.add_argument_group('Basic settings') group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user data directory. 
Default: auto-detected.') -group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.') +group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.') group.add_argument('--model', type=str, help='Name of the model to load by default.') group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.') group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.') @@ -76,7 +76,7 @@ group.add_argument('--loader', type=str, help='Choose the model loader manually, # Cache group = parser.add_argument_group('Context and cache') -group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, metavar='N', help='Context size in tokens. llama.cpp: 0 = auto if gpu-layers is also -1.') +group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=0, metavar='N', help='Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.') group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).') # Speculative decoding @@ -108,7 +108,7 @@ group.add_argument('--threads', type=int, default=0, help='Number of threads to group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.') group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.') group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. 
The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.') -group.add_argument('--fit-target', type=str, default='1024', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices. Default: 1024.') +group.add_argument('--fit-target', type=str, default='512', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.') group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"') # Transformers/Accelerate @@ -139,12 +139,6 @@ group.add_argument('--enable-tp', '--enable_tp', action='store_true', help='Enab group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.') group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.') -# RoPE -group = parser.add_argument_group('RoPE') -group.add_argument('--alpha_value', type=float, default=1, help='Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.') -group.add_argument('--rope_freq_base', type=int, default=0, help='If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).') -group.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). 
Equal to 1/rope_freq_scale.") - # Gradio group = parser.add_argument_group('Gradio') group.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') @@ -163,7 +157,7 @@ group.add_argument('--portable', action='store_true', help='Hide features not av # API group = parser.add_argument_group('API') group.add_argument('--api', action='store_true', help='Enable the API extension.') -group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.') +group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.') group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None) group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.') group.add_argument('--api-key', type=str, default='', help='API authentication key.') @@ -181,9 +175,10 @@ group.add_argument('--dynatemp-high', type=float, default=_d['dynatemp_high'], m group.add_argument('--dynatemp-exponent', type=float, default=_d['dynatemp_exponent'], metavar='N', help='Dynamic temperature exponent') group.add_argument('--smoothing-factor', type=float, default=_d['smoothing_factor'], metavar='N', help='Smoothing factor') group.add_argument('--smoothing-curve', type=float, default=_d['smoothing_curve'], metavar='N', help='Smoothing curve') -group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P') group.add_argument('--top-p', type=float, default=_d['top_p'], metavar='N', help='Top P') group.add_argument('--top-k', type=int, default=_d['top_k'], metavar='N', help='Top K') +group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P') +group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma') group.add_argument('--typical-p', type=float, 
default=_d['typical_p'], metavar='N', help='Typical P') group.add_argument('--xtc-threshold', type=float, default=_d['xtc_threshold'], metavar='N', help='XTC threshold') group.add_argument('--xtc-probability', type=float, default=_d['xtc_probability'], metavar='N', help='XTC probability') @@ -191,7 +186,6 @@ group.add_argument('--epsilon-cutoff', type=float, default=_d['epsilon_cutoff'], group.add_argument('--eta-cutoff', type=float, default=_d['eta_cutoff'], metavar='N', help='Eta cutoff') group.add_argument('--tfs', type=float, default=_d['tfs'], metavar='N', help='TFS') group.add_argument('--top-a', type=float, default=_d['top_a'], metavar='N', help='Top A') -group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma') group.add_argument('--adaptive-target', type=float, default=_d['adaptive_target'], metavar='N', help='Adaptive target') group.add_argument('--adaptive-decay', type=float, default=_d['adaptive_decay'], metavar='N', help='Adaptive decay') group.add_argument('--dry-multiplier', type=float, default=_d['dry_multiplier'], metavar='N', help='DRY multiplier') @@ -263,8 +257,9 @@ settings = { 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". 
Reply directly, without starting the reply with the character name.\n\n<|prompt|>', 'enable_web_search': False, 'web_search_pages': 3, + 'selected_tools': [], 'prompt-notebook': '', - 'preset': 'Qwen3 - Thinking' if (user_data_dir / 'presets/Qwen3 - Thinking.yaml').exists() else None, + 'preset': 'Top-P' if (user_data_dir / 'presets/Top-P.yaml').exists() else None, 'max_new_tokens': 512, 'max_new_tokens_min': 1, 'max_new_tokens_max': 4096, @@ -289,7 +284,7 @@ settings = { 'include_past_attachments': True, # Generation parameters - Curve shape - 'temperature': 0.6, + 'temperature': neutral_samplers['temperature'], 'dynatemp_low': neutral_samplers['dynatemp_low'], 'dynatemp_high': neutral_samplers['dynatemp_high'], 'dynatemp_exponent': neutral_samplers['dynatemp_exponent'], @@ -297,9 +292,10 @@ settings = { 'smoothing_curve': neutral_samplers['smoothing_curve'], # Generation parameters - Curve cutoff - 'min_p': neutral_samplers['min_p'], 'top_p': 0.95, - 'top_k': 20, + 'top_k': neutral_samplers['top_k'], + 'min_p': neutral_samplers['min_p'], + 'top_n_sigma': neutral_samplers['top_n_sigma'], 'typical_p': neutral_samplers['typical_p'], 'xtc_threshold': neutral_samplers['xtc_threshold'], 'xtc_probability': neutral_samplers['xtc_probability'], @@ -307,7 +303,6 @@ settings = { 'eta_cutoff': neutral_samplers['eta_cutoff'], 'tfs': neutral_samplers['tfs'], 'top_a': neutral_samplers['top_a'], - 'top_n_sigma': neutral_samplers['top_n_sigma'], 'adaptive_target': neutral_samplers['adaptive_target'], 'adaptive_decay': neutral_samplers['adaptive_decay'], @@ -347,7 +342,7 @@ settings = { 'greeting': 'How can I help you today?', 'custom_system_message': '', 'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. 
Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}", - 'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}", + 'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {{- '[Tool result: ' + message['content'] + ']\\n' -}}\n {%- elif message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- elif message['tool_calls'] is defined and message['tool_calls'] -%}\n {%- for tc in message['tool_calls'] -%}\n {{- '[Calling: ' + tc['function']['name'] + '(' + tc['function']['arguments'] + ')]\\n' -}}\n {%- endfor -%}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}", # Extensions 'default_extensions': [], @@ -395,9 +390,16 @@ def do_cmd_flags_warnings(): if 
args.share: logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.") if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)): - logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.") - if args.multi_user: - logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.') + logger.warning("You are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.") + if args.multi_user: + logger.warning( + 'Multi-user mode is enabled. Known limitations:' + '\n- The Stop button stops generation for all users, not just you.' + '\n- Chat history is not saved and will be lost on page refresh.' + '\n- Only one user can generate at a time unless using a parallel-capable backend (e.g. llama.cpp with --parallel N for N > 1, or ExLlamaV3).' + '\n\nThis mode works best for small trusted teams.' + '\n\nDo not expose publicly. 
Grayed-out actions can easily be bypassed client-side.\n' + ) def apply_image_model_cli_overrides(): diff --git a/modules/text_generation.py b/modules/text_generation.py index c78afe3e..d487cd2f 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -78,10 +78,13 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap reply = '' is_stream = state['stream'] if len(all_stop_strings) > 0 and not state['stream']: + original_logits_processor = state.get('logits_processor') stop_event_ref = state.pop('stop_event', None) state = copy.deepcopy(state) if stop_event_ref is not None: state['stop_event'] = stop_event_ref + if original_logits_processor is not None: + state['logits_processor'] = original_logits_processor state['stream'] = True # Generate @@ -375,7 +378,7 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None, generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()] if state['custom_token_bans']: - to_ban = [int(x) for x in state['custom_token_bans'].split(',')] + to_ban = [int(x.strip()) for x in state['custom_token_bans'].split(',') if x.strip()] if len(to_ban) > 0: if generate_params.get('suppress_tokens', None): generate_params['suppress_tokens'] += to_ban diff --git a/modules/tool_parsing.py b/modules/tool_parsing.py new file mode 100644 index 00000000..0454e901 --- /dev/null +++ b/modules/tool_parsing.py @@ -0,0 +1,667 @@ +import json +import random +import re + + +def get_tool_call_id() -> str: + letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789" + b = [random.choice(letter_bytes) for _ in range(8)] + return "call_" + "".join(b).lower() + + +# All known opening markers for tool calls across model formats. 
+TOOL_CALL_OPENING_MARKERS = [ + '', + '', + '', + '<|tool_call_begin|>', + '<|tool_calls_section_begin|>', + '<|tool▁call▁begin|>', + '<|tool▁calls▁begin|>', + '[TOOL_CALLS]', + 'to=functions.', + '<|channel|>commentary', +] + + +def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False): + ''' + Check whether streaming output should be withheld because it may + contain tool-call markup. + + Args: + text: Full accumulated internal text. + markers: Template-specific markers for partial-prefix matching. + If None, falls back to TOOL_CALL_OPENING_MARKERS. + tool_names: List of tool function names. + check_bare_names: Whether to do partial-prefix matching on tool + names (for models with unknown template format). + ''' + # Full marker found in text → buffer permanently. + # Always checks ALL known markers regardless of template (cheap safety net). + for marker in TOOL_CALL_OPENING_MARKERS: + if marker in text: + return True + + # Bare function-name full match: "get_weather{...}" or "get_weather {...}" + if tool_names: + for name in tool_names: + if name + '{' in text or name + ' {' in text: + return True + + # Partial-prefix matching: only for template-specific markers. + for marker in (markers if markers is not None else TOOL_CALL_OPENING_MARKERS): + for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1): + if text.endswith(marker[:prefix_len]): + return True + + # Bare-name partial matching: only when template format is unknown. 
+ if check_bare_names and tool_names: + for name in tool_names: + if text.endswith(name): + return True + for prefix_len in range(min(len(name) - 1, len(text)), 0, -1): + if text.endswith(name[:prefix_len]): + return True + + return False + + +def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]): + # check if property 'function' exists and is a dictionary, otherwise adapt dict + if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str): + candidate_dict = {"type": "function", "function": candidate_dict} + if 'function' in candidate_dict and isinstance(candidate_dict['function'], str): + candidate_dict['name'] = candidate_dict['function'] + del candidate_dict['function'] + candidate_dict = {"type": "function", "function": candidate_dict} + if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict): + # check if 'name' exists within 'function' and is part of known tools + if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names: + candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value + # map property 'parameters' used by some older models to 'arguments' + if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]: + candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"] + del candidate_dict["function"]["parameters"] + return candidate_dict + return None + + +def _extract_balanced_json(text: str, start: int) -> str | None: + """Extract a balanced JSON object from text starting at the given position. + + Walks through the string tracking brace depth and string boundaries + to correctly handle arbitrary nesting levels. 
+ """ + if start >= len(text) or text[start] != '{': + return None + depth = 0 + in_string = False + escape_next = False + for i in range(start, len(text)): + c = text[i] + if escape_next: + escape_next = False + continue + if c == '\\' and in_string: + escape_next = True + continue + if c == '"': + in_string = not in_string + continue + if in_string: + continue + if c == '{': + depth += 1 + elif c == '}': + depth -= 1 + if depth == 0: + return text[start:i + 1] + return None + + +def _parse_channel_tool_calls(answer: str, tool_names: list[str]): + """Parse channel-based tool calls used by GPT-OSS and similar models. + + Format: + <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"} + or: + <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"} + """ + matches = [] + start_pos = None + # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format) + # Pattern 2: to=functions.NAME after <|channel|> (alternative format) + patterns = [ + r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>', + r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>', + ] + for pattern in patterns: + for m in re.finditer(pattern, answer): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extract_balanced_json(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + prefix = answer.rfind('<|start|>assistant', 0, m.start()) + start_pos = prefix if prefix != -1 else m.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + if matches: + break + return matches, start_pos + + +def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]): + """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens. 
+ + Format: + [TOOL_CALLS]func_name[ARGS]{"arg": "value"} + """ + matches = [] + start_pos = None + for m in re.finditer( + r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*', + answer + ): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extract_balanced_json(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + start_pos = m.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + return matches, start_pos + + +def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]): + """Parse bare function-name style tool calls used by Mistral and similar models. + + Format: + functionName{"arg": "value"} + Multiple calls are concatenated directly or separated by whitespace. + """ + matches = [] + start_pos = None + # Match tool name followed by opening brace, then extract balanced JSON + escaped_names = [re.escape(name) for name in tool_names] + pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{' + for match in re.finditer(pattern, answer): + text = match.group(0) + name = None + for n in tool_names: + if text.startswith(n): + name = n + break + if not name: + continue + brace_start = match.end() - 1 + json_str = _extract_balanced_json(answer, brace_start) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + start_pos = match.start() + matches.append({ + "type": "function", + "function": { + "name": name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + return matches, start_pos + + +def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]): + """Parse XML-parameter style tool calls used by Qwen3.5 and similar models. 
+ + Format: + + + value + + + """ + matches = [] + start_pos = None + for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): + tc_content = tc_match.group(1) + func_match = re.search(r']+)>', tc_content) + if not func_match: + continue + func_name = func_match.group(1).strip() + if func_name not in tool_names: + continue + arguments = {} + for param_match in re.finditer(r']+)>\s*(.*?)\s*', tc_content, re.DOTALL): + param_name = param_match.group(1).strip() + param_value = param_match.group(2).strip() + try: + param_value = json.loads(param_value) + except (json.JSONDecodeError, ValueError): + pass # keep as string + arguments[param_name] = param_value + if start_pos is None: + start_pos = tc_match.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + return matches, start_pos + + +def _parse_kimi_tool_calls(answer: str, tool_names: list[str]): + """Parse Kimi-K2-style tool calls using pipe-delimited tokens. + + Format: + <|tool_calls_section_begin|> + <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|> + <|tool_calls_section_end|> + """ + matches = [] + start_pos = None + for m in re.finditer( + r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*', + answer + ): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extract_balanced_json(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + # Check for section begin marker before the call marker + section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start()) + start_pos = section if section != -1 else m.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + return matches, start_pos + + +def _parse_minimax_tool_calls(answer: str, 
tool_names: list[str]): + """Parse MiniMax-style tool calls using invoke/parameter XML tags. + + Format: + + + value + + + """ + matches = [] + start_pos = None + for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): + tc_content = tc_match.group(1) + # Split on to handle multiple parallel calls in one block + for invoke_match in re.finditer(r'(.*?)', tc_content, re.DOTALL): + func_name = invoke_match.group(1).strip() + if func_name not in tool_names: + continue + invoke_body = invoke_match.group(2) + arguments = {} + for param_match in re.finditer(r'\s*(.*?)\s*', invoke_body, re.DOTALL): + param_name = param_match.group(1).strip() + param_value = param_match.group(2).strip() + try: + param_value = json.loads(param_value) + except (json.JSONDecodeError, ValueError): + pass # keep as string + arguments[param_name] = param_value + if start_pos is None: + start_pos = tc_match.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + return matches, start_pos + + +def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]): + """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters. 
+
+    Format:
+    <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
+    """
+    matches = []
+    start_pos = None
+    for m in re.finditer(
+        r'<\|tool▁call▁begin\|>\s*(\S+?)\s*<\|tool▁sep\|>\s*',
+        answer
+    ):
+        func_name = m.group(1).strip()
+        if func_name not in tool_names:
+            continue
+        json_str = _extract_balanced_json(answer, m.end())
+        if json_str is None:
+            continue
+        try:
+            arguments = json.loads(json_str)
+            if start_pos is None:
+                # Check for section begin marker before the call marker
+                section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start())
+                start_pos = section if section != -1 else m.start()
+            matches.append({
+                "type": "function",
+                "function": {
+                    "name": func_name,
+                    "arguments": arguments
+                }
+            })
+        except json.JSONDecodeError:
+            pass
+    return matches, start_pos
+
+
+def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
+    """Parse GLM-style tool calls using arg_key/arg_value XML pairs.
+
+    Format:
+    function_name
+    key1
+    value1
+
+    """
+    matches = []
+    start_pos = None
+    for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL):
+        tc_content = tc_match.group(1)
+        # First non-tag text is the function name
+        name_match = re.match(r'([^<\s]+)', tc_content.strip())
+        if not name_match:
+            continue
+        func_name = name_match.group(1).strip()
+        if func_name not in tool_names:
+            continue
+        # Extract arg_key/arg_value pairs
+        keys = [k.group(1).strip() for k in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)]
+        vals = [v.group(1).strip() for v in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)]
+        if len(keys) != len(vals):
+            continue
+        arguments = {}
+        for k, v in zip(keys, vals):
+            try:
+                v = json.loads(v)
+            except (json.JSONDecodeError, ValueError):
+                pass  # keep as string
+            arguments[k] = v
+        if start_pos is None:
+            start_pos = tc_match.start()
+        matches.append({
+            "type": "function",
+            "function": {
+                "name": func_name,
+                "arguments": arguments
+            }
+        })
+    return matches, start_pos
+
+
+def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]): + """Parse pythonic-style tool calls used by Llama 4 and similar models. + + Format: + [func_name(param1="value1", param2="value2"), func_name2(...)] + """ + matches = [] + start_pos = None + # Match a bracketed list of function calls + bracket_match = re.search(r'\[([^\[\]]+)\]', answer) + if not bracket_match: + return matches, start_pos + + inner = bracket_match.group(1) + + # Build pattern for known tool names + escaped_names = [re.escape(name) for name in tool_names] + name_pattern = '|'.join(escaped_names) + + for call_match in re.finditer( + r'(' + name_pattern + r')\(([^)]*)\)', + inner + ): + func_name = call_match.group(1) + params_str = call_match.group(2).strip() + arguments = {} + + if params_str: + # Parse key="value" pairs, handling commas inside quoted values + for param_match in re.finditer( + r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)', + params_str + ): + param_name = param_match.group(1) + param_value = param_match.group(2).strip() + # Strip surrounding quotes + if (param_value.startswith('"') and param_value.endswith('"')) or \ + (param_value.startswith("'") and param_value.endswith("'")): + param_value = param_value[1:-1] + # Try to parse as JSON for numeric/bool/null values + try: + param_value = json.loads(param_value) + except (json.JSONDecodeError, ValueError): + pass + arguments[param_name] = param_value + + if start_pos is None: + start_pos = bracket_match.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + + return matches, start_pos + + +# Format registry: maps template substrings to the parser and streaming +# markers for that format. When a format's hints are NOT found in the +# template, its parser and markers are excluded. 
+TOOL_CALL_FORMATS = [ + { + 'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'], + 'parser': _parse_deep_seek_tool_calls, + 'markers': ['<|tool▁call▁begin|>', '<|tool▁calls▁begin|>'], + }, + { + 'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'], + 'parser': _parse_kimi_tool_calls, + 'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'], + }, + { + 'template_hints': ['to=functions.', '<|channel|>'], + 'parser': _parse_channel_tool_calls, + 'markers': ['to=functions.', '<|channel|>commentary'], + }, + { + 'template_hints': ['minimax:tool_call'], + 'parser': _parse_minimax_tool_calls, + 'markers': [''], + }, + { + 'template_hints': [''], + 'parser': _parse_glm_tool_calls, + 'markers': [''], + }, + { + 'template_hints': [''], + 'parser': _parse_xml_param_tool_calls, + 'markers': [''], + }, + { + 'template_hints': ['[TOOL_CALLS]'], + 'parser': _parse_mistral_token_tool_calls, + 'markers': ['[TOOL_CALLS]'], + }, + { + 'template_hints': [''], + 'parser': None, + 'markers': [''], + }, +] + +# Default ordered list of all specialized parsers. +ALL_PARSERS = [ + _parse_deep_seek_tool_calls, + _parse_kimi_tool_calls, + _parse_channel_tool_calls, + _parse_minimax_tool_calls, + _parse_glm_tool_calls, + _parse_xml_param_tool_calls, + _parse_mistral_token_tool_calls, + _parse_bare_name_tool_calls, + _parse_pythonic_tool_calls, +] + + +def detect_tool_call_format(template_str): + """Inspect a chat/instruction template to determine which tool call + formats are relevant. + + Uses an exclude-based approach: starts with all parsers/markers, + then removes the ones whose hints are not found in the template. + + Returns (parsers, streaming_markers, check_bare_names). 
+ """ + if not template_str: + return None, TOOL_CALL_OPENING_MARKERS, True + + matched_any = False + exclude_parsers = [] + exclude_markers = [] + matched_markers = [] + + for fmt in TOOL_CALL_FORMATS: + if any(hint in template_str for hint in fmt['template_hints']): + matched_any = True + matched_markers.extend(fmt['markers']) + else: + if fmt['parser'] is not None: + exclude_parsers.append(fmt['parser']) + exclude_markers.extend(fmt['markers']) + + if not matched_any: + return None, TOOL_CALL_OPENING_MARKERS, True + + parsers = [p for p in ALL_PARSERS if p not in exclude_parsers] + markers = [m for m in TOOL_CALL_OPENING_MARKERS if m not in exclude_markers or m in matched_markers] + + return parsers, markers, False + + +def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None): + matches = [] + start_pos = None + + def _return(matches, start_pos): + if return_prefix: + prefix = answer[:start_pos] if matches and start_pos is not None else '' + return matches, prefix + return matches + + # Try specialized parsers. 
+ for parser in (parsers if parsers is not None else ALL_PARSERS): + matches, start_pos = parser(answer, tool_names) + if matches: + return _return(matches, start_pos) + + # Generic fallback: regex pattern to find the JSON content wrapped in , , , and other tags observed from various models + patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)"] + + for pattern in patterns: + for match in re.finditer(pattern, answer, re.DOTALL): + if match.group(2) is None: + continue + # remove backtick wraps if present + candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip()) + candidate = re.sub(r"```$", "", candidate.strip()) + # unwrap inner tags + candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL) + # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually + if re.search(r"\}\s*\n\s*\{", candidate) is not None: + candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate) + if not candidate.strip().startswith("["): + candidate = "[" + candidate + "]" + + candidates = [] + try: + # parse the candidate JSON into a dictionary + candidates = json.loads(candidate) + if not isinstance(candidates, list): + candidates = [candidates] + except json.JSONDecodeError: + # Ignore invalid JSON silently + continue + + for candidate_dict in candidates: + checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names) + if checked_candidate is not None: + if start_pos is None: + start_pos = match.start() + matches.append(checked_candidate) + + # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags + if len(matches) == 0: + try: + candidate = answer + # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually + if re.search(r"\}\s*\n\s*\{", candidate) is not None: + candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", 
candidate) + if not candidate.strip().startswith("["): + candidate = "[" + candidate + "]" + # parse the candidate JSON into a dictionary + candidates = json.loads(candidate) + if not isinstance(candidates, list): + candidates = [candidates] + for candidate_dict in candidates: + checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names) + if checked_candidate is not None: + matches.append(checked_candidate) + except json.JSONDecodeError: + # Ignore invalid JSON silently + pass + + return _return(matches, start_pos) diff --git a/modules/tool_use.py b/modules/tool_use.py new file mode 100644 index 00000000..e22b1798 --- /dev/null +++ b/modules/tool_use.py @@ -0,0 +1,71 @@ +import importlib.util +import json + +from modules import shared +from modules.logging_colors import logger +from modules.utils import natural_keys, sanitize_filename + + +def get_available_tools(): + """Return sorted list of tool script names from user_data/tools/*.py.""" + tools_dir = shared.user_data_dir / 'tools' + tools_dir.mkdir(parents=True, exist_ok=True) + return sorted((p.stem for p in tools_dir.glob('*.py')), key=natural_keys) + + +def load_tools(selected_names): + """ + Import selected tool scripts and return their definitions and executors. 
+ Returns (tool_defs, executors) where: + - tool_defs: list of OpenAI-format tool dicts + - executors: dict mapping function_name -> execute callable + """ + tool_defs = [] + executors = {} + for name in selected_names: + name = sanitize_filename(name) + if not name: + continue + + path = shared.user_data_dir / 'tools' / f'{name}.py' + if not path.exists(): + continue + + try: + spec = importlib.util.spec_from_file_location(f"tool_{name}", str(path)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + except Exception: + logger.exception(f'Failed to load tool script "{name}"') + continue + + tool_def = getattr(module, 'tool', None) + execute_fn = getattr(module, 'execute', None) + if tool_def is None or execute_fn is None: + logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.') + continue + + func_name = tool_def.get('function', {}).get('name', name) + if func_name in executors: + logger.warning(f'Tool "{name}" declares function name "{func_name}" which conflicts with an already loaded tool. Skipping.') + continue + tool_defs.append(tool_def) + executors[func_name] = execute_fn + + return tool_defs, executors + + +def execute_tool(func_name, arguments, executors): + """Execute a tool by function name. 
Returns result as a JSON string.""" + fn = executors.get(func_name) + if fn is None: + return json.dumps({"error": f"Unknown tool: {func_name}"}) + + try: + if isinstance(arguments, str): + arguments = json.loads(arguments) + result = fn(arguments) + return json.dumps(result) if not isinstance(result, str) else result + except Exception as e: + logger.exception(f'Tool "{func_name}" execution failed') + return json.dumps({"error": str(e)}) diff --git a/modules/training.py b/modules/training.py index 2e172d22..878bb222 100644 --- a/modules/training.py +++ b/modules/training.py @@ -310,6 +310,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: # == Input validation / processing == yield "Preparing the input..." + + if shared.args.loader == 'llama.cpp': + yield "Error: LoRA training requires a model loaded with the Transformers loader. GGUF models are not supported for training." + return + lora_file_path = clean_path(None, lora_name) if lora_file_path.strip() == '': yield "Missing or invalid LoRA file name input." 
diff --git a/modules/transformers_loader.py b/modules/transformers_loader.py index d57020c6..63758ad7 100644 --- a/modules/transformers_loader.py +++ b/modules/transformers_loader.py @@ -65,14 +65,16 @@ class LogprobProcessor(LogitsProcessor): def __init__(self, logprobs=None): self.logprobs = logprobs self.token_alternatives = {} + self.token_alternatives_history = [] def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor: if self.logprobs is not None: # 0-5 log_e_probabilities = F.log_softmax(logits, dim=1) - top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1) + top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs) top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]] top_probs = [float(x) for x in top_values[0]] self.token_alternatives = dict(zip(top_tokens, top_probs)) + self.token_alternatives_history.append(self.token_alternatives) return logits @@ -134,8 +136,6 @@ def load_model_HF(model_name): shared.args.load_in_4bit, shared.args.disk, shared.args.cpu_memory is not None, - shared.args.compress_pos_emb > 1, - shared.args.alpha_value > 1, ]) # Load the model without any special settings @@ -198,11 +198,6 @@ def load_model_HF(model_name): if shared.args.disk: params['offload_folder'] = str(Path(shared.args.disk_cache_dir)) - if shared.args.compress_pos_emb > 1: - params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb} - elif shared.args.alpha_value > 1: - params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value} - logger.info("TRANSFORMERS_PARAMS=") pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params) print() diff --git a/modules/ui.py b/modules/ui.py index 70e929f2..3f39a1a4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -120,58 +120,8 @@ else: def list_model_elements(): - elements = [ - 'filter_by_loader', - 'loader', - 'cpu_memory', - 'gpu_layers', - 'fit_target', - 'cpu_moe', - 
'threads', - 'threads_batch', - 'batch_size', - 'ubatch_size', - 'ctx_size', - 'cache_type', - 'tensor_split', - 'extra_flags', - 'streaming_llm', - 'gpu_split', - 'alpha_value', - 'rope_freq_base', - 'compress_pos_emb', - 'compute_dtype', - 'quant_type', - 'load_in_8bit', - 'load_in_4bit', - 'attn_implementation', - 'cpu', - 'disk', - 'row_split', - 'no_kv_offload', - 'no_mmap', - 'mlock', - 'numa', - 'parallel', - 'use_double_quant', - 'bf16', - 'enable_tp', - 'tp_backend', - 'cfg_cache', - 'no_use_fast', - 'model_draft', - 'draft_max', - 'gpu_layers_draft', - 'device_draft', - 'ctx_size_draft', - 'spec_type', - 'spec_ngram_size_n', - 'spec_ngram_size_m', - 'spec_ngram_min_hits', - 'mmproj', - ] - - return elements + from modules.loaders import list_model_elements + return list_model_elements() def list_interface_input_elements(): @@ -249,6 +199,7 @@ def list_interface_input_elements(): 'unique_id', 'textbox', 'start_with', + 'selected_tools', 'mode', 'chat_style', 'chat-instruct_command', @@ -353,12 +304,16 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma if k in shared.settings and k not in exclude: output[k] = state[k] - output['preset'] = preset + if preset: + output['preset'] = preset output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook'] - output['character'] = state['character_menu'] - if 'user_menu' in state and state['user_menu']: + if state.get('character_menu'): + output['character'] = state['character_menu'] + if state.get('user_menu'): output['user'] = state['user_menu'] output['seed'] = int(output['seed']) + output['custom_stopping_strings'] = output.get('custom_stopping_strings') or '' + output['custom_token_bans'] = output.get('custom_token_bans') or '' output['show_controls'] = show_controls output['dark_theme'] = True if theme_state == 'dark' else False output.pop('instruction_template_str') @@ -470,6 +425,7 @@ def setup_auto_save(): 
'user_bio', 'custom_system_message', 'chat_template_str', + 'selected_tools', # Parameters tab (ui_parameters.py) - Generation parameters 'preset_menu', @@ -520,7 +476,6 @@ def setup_auto_save(): 'skip_special_tokens', 'stream', 'static_cache', - 'truncation_length', 'seed', 'sampler_priority', 'custom_stopping_strings', diff --git a/modules/ui_chat.py b/modules/ui_chat.py index 74da0a40..d2a515b8 100644 --- a/modules/ui_chat.py +++ b/modules/ui_chat.py @@ -28,7 +28,8 @@ def create_ui(): shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu) shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu) shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat') - shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input']) + shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'], elem_id='new-chat-btn') + shared.gradio['Start incognito chat'] = gr.Button('Incognito chat', visible=False, elem_id='incognito-chat-btn') shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True) shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat') @@ -91,6 +92,21 @@ def create_ui(): gr.HTML("") + from modules.tool_use import get_available_tools + shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=shared.settings.get('selected_tools', []), label='Tools', info='Functions the model can call during generation.', elem_id='tools-group') + shared.gradio['tools_refresh'] = gr.Button('Refresh list', elem_id='tools-refresh-btn', visible=False) + 
shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']]) + + def sync_web_tools(selected): + if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools(): + selected.append('fetch_webpage') + + return gr.update(value=selected) + + shared.gradio['selected_tools'].change(fn=sync_web_tools, inputs=[shared.gradio['selected_tools']], outputs=[shared.gradio['selected_tools']], show_progress=False) + + gr.HTML("") + with gr.Row(): shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode') @@ -275,6 +291,10 @@ def create_event_handlers(): ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False) + shared.gradio['Start incognito chat'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + chat.handle_start_incognito_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False) + shared.gradio['delete_chat-confirm'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False) diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py index e9df9bd3..dc108f6d 100644 --- a/modules/ui_image_generation.py +++ b/modules/ui_image_generation.py @@ -728,6 +728,8 @@ def generate_prompt_variation(state): variation = variation.rsplit("", 1)[1] elif "<|start|>assistant<|channel|>final<|message|>" in variation: variation = 
variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1] + elif "<|channel|>final<|message|>" in variation: + variation = variation.rsplit("<|channel|>final<|message|>", 1)[1] elif "" in variation: variation = variation.rsplit("", 1)[1] diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 7e91f1ce..5cf0155d 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -42,11 +42,11 @@ def create_ui(): with gr.Row(): with gr.Column(): shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=-1, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Number of layers to offload to the GPU. -1 = auto.') - shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. llama.cpp: 0 = auto if gpu-layers is also -1. Common values: 4096, 8192, 16384, 32768, 65536, 131072.') + shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=1048576, step=1024, value=shared.args.ctx_size, info='Context length. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. Common values: 4096, 8192, 16384, 32768, 65536, 131072.') shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7') shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.') shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. 
q4_q8).') - shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices. Default: 1024.') + shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices.') shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.') with gr.Column(): @@ -100,9 +100,6 @@ def create_ui(): shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40') shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags) shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory) - shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.') - shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.') - shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). 
Equal to 1/rope_freq_scale.') shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.') shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.') @@ -388,7 +385,11 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur def update_truncation_length(current_length, state): if 'loader' in state: if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp': - return state['ctx_size'] + if state['ctx_size'] > 0: + return state['ctx_size'] + + # ctx_size == 0 means auto: use the actual value from the server + return shared.settings['truncation_length'] return current_length diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index e5eb9210..5411b294 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -37,10 +37,10 @@ def create_ui(): shared.gradio['dynamic_temperature'] = gr.Checkbox(value=shared.settings['dynamic_temperature'], label='dynamic_temperature') gr.Markdown('## Curve cutoff') - shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=shared.settings['min_p'], step=0.01, label='min_p') - shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=shared.settings['top_n_sigma'], step=0.01, label='top_n_sigma') shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=shared.settings['top_p'], step=0.01, label='top_p') shared.gradio['top_k'] = gr.Slider(0, 200, value=shared.settings['top_k'], step=1, label='top_k') + shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=shared.settings['min_p'], step=0.01, label='min_p') + shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=shared.settings['top_n_sigma'], step=0.01, label='top_n_sigma') shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=shared.settings['typical_p'], step=0.01, label='typical_p') shared.gradio['xtc_threshold'] = 
gr.Slider(0, 0.5, value=shared.settings['xtc_threshold'], step=0.01, label='xtc_threshold', info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.') shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=shared.settings['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.') @@ -73,7 +73,7 @@ def create_ui(): gr.Markdown('## Other options') shared.gradio['do_sample'] = gr.Checkbox(value=shared.settings['do_sample'], label='do_sample') shared.gradio['temperature_last'] = gr.Checkbox(value=shared.settings['temperature_last'], label='temperature_last', info='Moves temperature/dynamic temperature/quadratic sampling to the end of the sampler stack, ignoring their positions in "Sampler priority".') - shared.gradio['sampler_priority'] = gr.Textbox(value=shared.settings['sampler_priority'], lines=10, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar']) + shared.gradio['sampler_priority'] = gr.DragDrop(value=shared.settings['sampler_priority'], label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar']) shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=shared.settings['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. 
Specified as a comma-separated list of quoted strings.') with gr.Column(): diff --git a/modules/web_search.py b/modules/web_search.py index 597af4b2..e13ef62a 100644 --- a/modules/web_search.py +++ b/modules/web_search.py @@ -1,11 +1,12 @@ import concurrent.futures import html +import ipaddress import random import re -import urllib.request +import socket from concurrent.futures import as_completed from datetime import datetime -from urllib.parse import quote_plus +from urllib.parse import parse_qs, quote_plus, urljoin, urlparse import requests @@ -13,34 +14,60 @@ from modules import shared from modules.logging_colors import logger +def _validate_url(url): + """Validate that a URL is safe to fetch (not targeting private/internal networks).""" + parsed = urlparse(url) + if parsed.scheme not in ('http', 'https'): + raise ValueError(f"Unsupported URL scheme: {parsed.scheme}") + + hostname = parsed.hostname + if not hostname: + raise ValueError("No hostname in URL") + + # Resolve hostname and check all returned addresses + try: + for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None): + ip = ipaddress.ip_address(sockaddr[0]) + if not ip.is_global: + raise ValueError(f"Access to non-public address {ip} is blocked") + except socket.gaierror: + raise ValueError(f"Could not resolve hostname: {hostname}") + + def get_current_timestamp(): """Returns the current time in 24-hour format""" return datetime.now().strftime('%b %d, %Y %H:%M') -def download_web_page(url, timeout=10): +def download_web_page(url, timeout=10, include_links=False): """ - Download a web page and convert its HTML content to structured Markdown text. + Download a web page and extract its main content as Markdown text. 
""" - import html2text + import trafilatura try: + _validate_url(url) headers = { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' } - response = requests.get(url, headers=headers, timeout=timeout) - response.raise_for_status() # Raise an exception for bad status codes + max_redirects = 5 + for _ in range(max_redirects): + response = requests.get(url, headers=headers, timeout=timeout, allow_redirects=False) + if response.is_redirect and 'Location' in response.headers: + url = urljoin(url, response.headers['Location']) + _validate_url(url) + else: + break - # Initialize the HTML to Markdown converter - h = html2text.HTML2Text() - h.body_width = 0 - h.ignore_images = True - h.ignore_links = True + response.raise_for_status() - # Convert the HTML to Markdown - markdown_text = h.handle(response.text) - - return markdown_text + result = trafilatura.extract( + response.text, + include_links=include_links, + output_format='markdown', + url=url + ) + return result or "" except requests.exceptions.RequestException as e: logger.error(f"Error downloading {url}: {e}") return "" @@ -49,8 +76,8 @@ def download_web_page(url, timeout=10): return "" -def perform_web_search(query, num_pages=3, max_workers=5, timeout=10): - """Perform web search and return results with content""" +def perform_web_search(query, num_pages=3, max_workers=5, timeout=10, fetch_content=True): + """Perform web search and return results, optionally with page content""" try: search_url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}" @@ -59,25 +86,41 @@ def perform_web_search(query, num_pages=3, max_workers=5, timeout=10): "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" ] - response_text = "" - req = urllib.request.Request(search_url, headers={'User-Agent': random.choice(agents)}) - with urllib.request.urlopen(req, timeout=timeout) as response: - response_text = response.read().decode('utf-8') + 
response = requests.get(search_url, headers={'User-Agent': random.choice(agents)}, timeout=timeout) + response.raise_for_status() + response_text = response.text - # Extract results with regex - titles = re.findall(r']*class="[^"]*result__a[^"]*"[^>]*>(.*?)', response_text, re.DOTALL) - urls = re.findall(r']*class="[^"]*result__url[^"]*"[^>]*>(.*?)', response_text, re.DOTALL) + # Extract results - title and URL come from the same element + result_links = re.findall(r']*class="[^"]*result__a[^"]*"[^>]*>(.*?)', response_text, re.DOTALL) + result_tags = re.findall(r']*class="[^"]*result__a[^"]*"[^>]*)>', response_text, re.DOTALL) # Prepare download tasks download_tasks = [] - for i in range(min(len(titles), len(urls), num_pages)): - url = f"https://{urls[i].strip()}" - title = re.sub(r'<[^>]+>', '', titles[i]).strip() - title = html.unescape(title) - download_tasks.append((url, title, i)) + for i, (tag_attrs, raw_title) in enumerate(zip(result_tags, result_links)): + if num_pages is not None and i >= num_pages: + break + # Extract href and resolve the actual URL from DuckDuckGo's redirect link + href_match = re.search(r'href="([^"]*)"', tag_attrs) + if not href_match: + continue + uddg = parse_qs(urlparse(html.unescape(href_match.group(1))).query).get('uddg', [''])[0] + if not uddg: + continue + title = html.unescape(re.sub(r'<[^>]+>', '', raw_title).strip()) + download_tasks.append((uddg, title, len(download_tasks))) search_results = [None] * len(download_tasks) # Pre-allocate to maintain order + if not fetch_content: + for url, title, index in download_tasks: + search_results[index] = { + 'title': title, + 'url': url, + 'content': '' + } + + return search_results + # Download pages in parallel with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all download tasks diff --git a/one_click.py b/one_click.py index 5131206e..d6ba9039 100644 --- a/one_click.py +++ b/one_click.py @@ -91,7 +91,7 @@ def get_gpu_choice(): "What is your 
GPU?", { 'A': 'NVIDIA', - 'B': 'AMD - Linux/macOS only, requires ROCm 6.4', + 'B': 'AMD - Linux only, ROCm 7.2', 'C': 'Apple M Series', 'D': 'Intel Arc (beta)', 'N': 'CPU mode' @@ -116,14 +116,12 @@ def get_pytorch_install_command(gpu_choice): if gpu_choice == "NVIDIA_CUDA128": return base_cmd + "--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback elif gpu_choice == "AMD": - return base_cmd + "--index-url https://download.pytorch.org/whl/rocm6.4" + pypi_fallback + py_tag = f"cp{PYTHON_VERSION.replace('.', '')}" + return f"python -m pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl" elif gpu_choice in ["APPLE", "NONE"]: return base_cmd + "--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback elif gpu_choice == "INTEL": - if is_linux(): - return "python -m pip install torch==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" - else: - return "python -m pip install torch==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" + return base_cmd + "--index-url https://download.pytorch.org/whl/xpu" else: return base_cmd @@ -136,12 +134,12 @@ def get_pytorch_update_command(gpu_choice): if gpu_choice == "NVIDIA_CUDA128": return f"{base_cmd}--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback elif gpu_choice == "AMD": - return f"{base_cmd}--index-url https://download.pytorch.org/whl/rocm6.4" + pypi_fallback + py_tag = f"cp{PYTHON_VERSION.replace('.', '')}" + return f"python -m pip install --upgrade https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl" elif gpu_choice in ["APPLE", "NONE"]: return f"{base_cmd}--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback elif gpu_choice == "INTEL": - 
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10" - return f"{base_cmd}{intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" + return f"{base_cmd}--index-url https://download.pytorch.org/whl/xpu" else: return base_cmd @@ -196,6 +194,8 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, if environment: if is_windows(): conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat") + python_path = os.path.join(conda_env_path, "python.exe") + cmd = cmd.replace("python ", f'"{python_path}" ') cmd = f'"{conda_bat_path}" activate "{conda_env_path}" >nul && {cmd}' else: conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh") @@ -270,7 +270,7 @@ def update_pytorch_and_python(): def clean_outdated_pytorch_cuda_dependencies(): - patterns = ["cu121", "cu122", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"] + patterns = ["cu121", "cu122", "rocm6", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"] result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True) matching_packages = [] @@ -316,13 +316,6 @@ def install_webui(): install_pytorch = get_pytorch_install_command(gpu_choice) run_cmd(f"conda install -y ninja git && {install_pytorch}", assert_success=True, environment=True) - if gpu_choice == "INTEL": - # Install oneAPI dependencies via conda - print_big_message("Installing Intel oneAPI runtime libraries.") - run_cmd("conda install -y -c https://software.repos.intel.com/python/conda/ -c conda-forge dpcpp-cpp-rt=2024.0 mkl-dpcpp=2024.0", environment=True) - # Install libuv required by Intel-patched torch - run_cmd("conda install -y libuv", environment=True) - # Install the webui requirements update_requirements(initial_installation=True, pull=False) @@ -365,8 +358,10 @@ def 
update_requirements(initial_installation=False, pull=True): current_commit = get_current_commit() wheels_changed = not os.path.exists(state_file) + installed_wheels = set() if not wheels_changed: state = load_state() + installed_wheels = set(state.get('installed_wheels', [])) if 'wheels_changed' in state or state.get('last_installed_commit') != current_commit: wheels_changed = True @@ -431,9 +426,17 @@ def update_requirements(initial_installation=False, pull=True): # Prepare the requirements file textgen_requirements = open(requirements_file).read().splitlines() + all_whl_lines = [line.strip() for line in textgen_requirements if '.whl' in line] - if not initial_installation and not wheels_changed: - textgen_requirements = [line for line in textgen_requirements if '.whl' not in line] + if not initial_installation: + if installed_wheels: + # Per-wheel comparison: only re-download wheels that changed + textgen_requirements = [ + line for line in textgen_requirements + if '.whl' not in line or line.strip() not in installed_wheels + ] + elif not wheels_changed: + textgen_requirements = [line for line in textgen_requirements if '.whl' not in line] with open('temp_requirements.txt', 'w') as file: file.write('\n'.join(textgen_requirements)) @@ -452,6 +455,7 @@ def update_requirements(initial_installation=False, pull=True): # Save state after successful installation state = load_state() state['last_installed_commit'] = current_commit + state['installed_wheels'] = all_whl_lines state.pop('wheels_changed', None) save_state(state) diff --git a/requirements/full/requirements.txt b/requirements/full/requirements.txt index eaf34fa8..c24f4a9d 100644 --- a/requirements/full/requirements.txt +++ b/requirements/full/requirements.txt @@ -2,11 +2,10 @@ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" bitsandbytes==0.49.* datasets -diffusers==0.36.* +diffusers==0.37.* einops fastapi==0.112.4 flash-linear-attention==0.4.* -html2text==2025.4.15 huggingface-hub==1.5.* 
jinja2==3.1.6 markdown @@ -25,14 +24,15 @@ scipy sentencepiece tensorboard torchao==0.15.* +trafilatura==2.0.0 transformers==5.3.* triton-windows==3.5.1.post24; platform_system == "Windows" tqdm wandb # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -40,9 +40,9 @@ sse-starlette==1.6.5 tiktoken # CUDA wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" -https://github.com/turboderp-org/exllamav3/releases/download/v0.0.23/exllamav3-0.0.23+cu128.torch2.9.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13" -https://github.com/turboderp-org/exllamav3/releases/download/v0.0.23/exllamav3-0.0.23+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" 
+https://github.com/turboderp-org/exllamav3/releases/download/v0.0.25/exllamav3-0.0.25+cu128.torch2.9.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13" +https://github.com/turboderp-org/exllamav3/releases/download/v0.0.25/exllamav3-0.0.25+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13" https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13" https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13" diff --git a/requirements/full/requirements_amd.txt b/requirements/full/requirements_amd.txt index 3211f251..7c481224 100644 --- a/requirements/full/requirements_amd.txt +++ b/requirements/full/requirements_amd.txt @@ -1,10 +1,9 @@ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets -diffusers==0.36.* +diffusers==0.37.* einops fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -25,11 +24,12 @@ tensorboard torchao==0.15.* transformers==5.3.* tqdm +trafilatura==2.0.0 wandb # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -37,5 +37,5 @@ sse-starlette==1.6.5 tiktoken # AMD wheels 
-https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" diff --git a/requirements/full/requirements_apple_intel.txt b/requirements/full/requirements_apple_intel.txt index 8d452114..b1c8f78e 100644 --- a/requirements/full/requirements_apple_intel.txt +++ b/requirements/full/requirements_apple_intel.txt @@ -1,10 +1,9 @@ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets -diffusers==0.36.* +diffusers==0.37.* einops fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -25,11 +24,12 @@ tensorboard torchao==0.15.* transformers==5.3.* tqdm +trafilatura==2.0.0 wandb # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -37,4 +37,4 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_x86_64.whl; platform_system == 
"Darwin" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin" diff --git a/requirements/full/requirements_apple_silicon.txt b/requirements/full/requirements_apple_silicon.txt index 525ceed5..63ef33ea 100644 --- a/requirements/full/requirements_apple_silicon.txt +++ b/requirements/full/requirements_apple_silicon.txt @@ -1,10 +1,9 @@ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets -diffusers==0.36.* +diffusers==0.37.* einops fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -25,11 +24,12 @@ tensorboard torchao==0.15.* transformers==5.3.* tqdm +trafilatura==2.0.0 wandb # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -37,4 +37,4 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" diff --git a/requirements/full/requirements_cpu_only.txt b/requirements/full/requirements_cpu_only.txt index 86b65a97..4bc61622 100644 --- a/requirements/full/requirements_cpu_only.txt +++ b/requirements/full/requirements_cpu_only.txt @@ -1,10 +1,9 @@ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets -diffusers==0.36.* +diffusers==0.37.* einops 
fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -25,11 +24,12 @@ tensorboard torchao==0.15.* transformers==5.3.* tqdm +trafilatura==2.0.0 wandb # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -37,5 +37,5 @@ sse-starlette==1.6.5 tiktoken # llama.cpp (CPU only) -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements/full/requirements_nowheels.txt b/requirements/full/requirements_nowheels.txt index 0a924d31..2ec1e61e 100644 --- a/requirements/full/requirements_nowheels.txt +++ b/requirements/full/requirements_nowheels.txt @@ -1,10 +1,9 @@ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets -diffusers==0.36.* +diffusers==0.37.* einops fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -25,11 +24,12 @@ tensorboard torchao==0.15.* transformers==5.3.* tqdm 
+trafilatura==2.0.0 wandb # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 diff --git a/requirements/portable/requirements.txt b/requirements/portable/requirements.txt index 61c9ef73..ba4c7a04 100644 --- a/requirements/portable/requirements.txt +++ b/requirements/portable/requirements.txt @@ -1,6 +1,5 @@ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -11,11 +10,12 @@ python-docx==1.1.2 pyyaml requests rich +trafilatura==2.0.0 tqdm # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -23,5 +23,5 @@ sse-starlette==1.6.5 tiktoken # CUDA wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" 
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" diff --git a/requirements/portable/requirements_amd.txt b/requirements/portable/requirements_amd.txt index 3d0785a3..5dfdd9c8 100644 --- a/requirements/portable/requirements_amd.txt +++ b/requirements/portable/requirements_amd.txt @@ -1,6 +1,5 @@ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -11,11 +10,12 @@ python-docx==1.1.2 pyyaml requests rich +trafilatura==2.0.0 tqdm # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -23,5 +23,5 @@ sse-starlette==1.6.5 tiktoken # AMD wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows" 
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" diff --git a/requirements/portable/requirements_apple_intel.txt b/requirements/portable/requirements_apple_intel.txt index 6805e209..f62241b3 100644 --- a/requirements/portable/requirements_apple_intel.txt +++ b/requirements/portable/requirements_apple_intel.txt @@ -1,6 +1,5 @@ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -11,11 +10,12 @@ python-docx==1.1.2 pyyaml requests rich +trafilatura==2.0.0 tqdm # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -23,4 +23,4 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin" diff --git a/requirements/portable/requirements_apple_silicon.txt b/requirements/portable/requirements_apple_silicon.txt index 5a8ed87b..353d9172 100644 --- a/requirements/portable/requirements_apple_silicon.txt +++ b/requirements/portable/requirements_apple_silicon.txt @@ -1,6 +1,5 @@ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -11,11 
+10,12 @@ python-docx==1.1.2 pyyaml requests rich +trafilatura==2.0.0 tqdm # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -23,4 +23,4 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" diff --git a/requirements/portable/requirements_cpu_only.txt b/requirements/portable/requirements_cpu_only.txt index fafa23cf..5f039318 100644 --- a/requirements/portable/requirements_cpu_only.txt +++ b/requirements/portable/requirements_cpu_only.txt @@ -1,6 +1,5 @@ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -11,11 +10,12 @@ python-docx==1.1.2 pyyaml requests rich +trafilatura==2.0.0 tqdm # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -23,5 +23,5 @@ sse-starlette==1.6.5 tiktoken # llama.cpp 
(CPU only) -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements/portable/requirements_cuda131.txt b/requirements/portable/requirements_cuda131.txt index 3ef59f97..d8b03102 100644 --- a/requirements/portable/requirements_cuda131.txt +++ b/requirements/portable/requirements_cuda131.txt @@ -1,6 +1,5 @@ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -11,11 +10,12 @@ python-docx==1.1.2 pyyaml requests rich +trafilatura==2.0.0 tqdm # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -23,5 +23,5 @@ sse-starlette==1.6.5 tiktoken # CUDA wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu131-py3-none-win_amd64.whl; platform_system == "Windows" 
-https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu131-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu131-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu131-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" diff --git a/requirements/portable/requirements_nowheels.txt b/requirements/portable/requirements_nowheels.txt index c2fc33eb..4b548dae 100644 --- a/requirements/portable/requirements_nowheels.txt +++ b/requirements/portable/requirements_nowheels.txt @@ -1,6 +1,5 @@ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -11,11 +10,12 @@ python-docx==1.1.2 pyyaml requests rich +trafilatura==2.0.0 tqdm # Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 diff --git a/requirements/portable/requirements_vulkan.txt b/requirements/portable/requirements_vulkan.txt index 6039357d..fd2511f4 100644 --- a/requirements/portable/requirements_vulkan.txt +++ b/requirements/portable/requirements_vulkan.txt @@ -1,6 +1,5 @@ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 -html2text==2025.4.15 huggingface-hub==1.5.* jinja2==3.1.6 markdown @@ -11,11 +10,12 @@ python-docx==1.1.2 pyyaml requests rich +trafilatura==2.0.0 tqdm 
# Gradio -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl -https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl +https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 @@ -23,5 +23,5 @@ sse-starlette==1.6.5 tiktoken # Vulkan wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" diff --git a/server.py b/server.py index ff2d1db2..1aa9fc04 100644 --- a/server.py +++ b/server.py @@ -1,58 +1,20 @@ -import os -import shutil -import warnings -from pathlib import Path - -from modules import shared, ui # ui must be imported early to avoid circular imports -from modules.image_models import load_image_model -from modules.logging_colors import logger -from modules.prompts import load_prompt - -# Set up Gradio temp directory path -gradio_temp_path = shared.user_data_dir / 'cache' / 'gradio' -shutil.rmtree(gradio_temp_path, ignore_errors=True) -gradio_temp_path.mkdir(parents=True, exist_ok=True) - -# Set environment variables -os.environ.update({ - 'GRADIO_ANALYTICS_ENABLED': 'False', - 
'BITSANDBYTES_NOWELCOME': '1', - 'GRADIO_TEMP_DIR': str(gradio_temp_path) -}) - -warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') -warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated') -warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict') -warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()') -warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict') - -import gradio as gr - import os import signal import sys import time +import warnings from functools import partial +from pathlib import Path from threading import Lock, Thread import yaml +from modules import shared, utils +from modules.image_models import load_image_model +from modules.logging_colors import logger +from modules.prompts import load_prompt + import modules.extensions as extensions_module -from modules import ( - training, - ui, - ui_chat, - ui_default, - ui_file_saving, - ui_image_generation, - ui_model_menu, - ui_notebook, - ui_parameters, - ui_session, - utils -) -from modules.chat import generate_pfp_cache -from modules.extensions import apply_extensions from modules.LoRA import add_lora_to_model from modules.models import load_model, unload_model_if_idle from modules.models_settings import ( @@ -61,10 +23,20 @@ from modules.models_settings import ( update_model_parameters ) from modules.shared import do_cmd_flags_warnings -from modules.utils import gradio + +os.environ['BITSANDBYTES_NOWELCOME'] = '1' + +warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') +warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated') +warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict') +warnings.filterwarnings('ignore', category=UserWarning, 
message='Field "model_names" has conflict') def signal_handler(sig, frame): + # On second Ctrl+C, force an immediate exit + signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGTERM, signal.SIG_DFL) + logger.info("Received Ctrl+C. Shutting down Text Generation Web UI gracefully.") # Explicitly stop LlamaServer to avoid __del__ cleanup issues during shutdown @@ -83,6 +55,37 @@ signal.signal(signal.SIGTERM, signal_handler) def create_interface(): + import shutil + + import gradio as gr + + from modules import ( + training, + ui, + ui_chat, + ui_default, + ui_file_saving, + ui_image_generation, + ui_model_menu, + ui_notebook, + ui_parameters, + ui_session, + ) + from modules.chat import generate_pfp_cache + from modules.extensions import apply_extensions + from modules.utils import gradio + + warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()') + + # Set up Gradio temp directory path + gradio_temp_path = shared.user_data_dir / 'cache' / 'gradio' + shutil.rmtree(gradio_temp_path, ignore_errors=True) + gradio_temp_path.mkdir(parents=True, exist_ok=True) + os.environ.update({ + 'GRADIO_ANALYTICS_ENABLED': 'False', + 'GRADIO_TEMP_DIR': str(gradio_temp_path) + }) + title = 'Text Generation Web UI' # Password authentication @@ -215,6 +218,10 @@ def create_interface(): shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False) + # Sync theme_state with the actual client-side theme so that + # autosave always writes the correct dark_theme value. + shared.gradio['interface'].load(None, None, gradio('theme_state'), js='() => document.body.classList.contains("dark") ? 
"dark" : "light"') + extensions_module.create_extensions_tabs() # Extensions tabs extensions_module.create_extensions_block() # Extensions block diff --git a/start_windows.bat b/start_windows.bat index dd096760..8da6986f 100755 --- a/start_windows.bat +++ b/start_windows.bat @@ -5,6 +5,7 @@ setlocal enabledelayedexpansion set PYTHONNOUSERSITE=1 set PYTHONPATH= set PYTHONHOME= +set PYTHONUTF8=1 cd /D "%~dp0" diff --git a/user_data/presets/Instruct.yaml b/user_data/presets/Instruct.yaml deleted file mode 100644 index 142fcd82..00000000 --- a/user_data/presets/Instruct.yaml +++ /dev/null @@ -1 +0,0 @@ -min_p: 0.2 diff --git a/user_data/presets/Qwen3 - No Thinking.yaml b/user_data/presets/Qwen3 - No Thinking.yaml deleted file mode 100644 index b1c1e03c..00000000 --- a/user_data/presets/Qwen3 - No Thinking.yaml +++ /dev/null @@ -1,3 +0,0 @@ -temperature: 0.7 -top_p: 0.8 -top_k: 20 diff --git a/user_data/presets/Qwen3 - Thinking.yaml b/user_data/presets/Qwen3 - Thinking.yaml deleted file mode 100644 index cb2942f9..00000000 --- a/user_data/presets/Qwen3 - Thinking.yaml +++ /dev/null @@ -1,3 +0,0 @@ -temperature: 0.6 -top_p: 0.95 -top_k: 20 diff --git a/user_data/presets/Top-P.yaml b/user_data/presets/Top-P.yaml new file mode 100644 index 00000000..f39e148f --- /dev/null +++ b/user_data/presets/Top-P.yaml @@ -0,0 +1 @@ +top_p: 0.95 diff --git a/user_data/presets/min_p.yaml b/user_data/presets/min_p.yaml deleted file mode 100644 index b8ebc95f..00000000 --- a/user_data/presets/min_p.yaml +++ /dev/null @@ -1 +0,0 @@ -min_p: 0.05 diff --git a/user_data/tools/calculate.py b/user_data/tools/calculate.py new file mode 100644 index 00000000..94f74c41 --- /dev/null +++ b/user_data/tools/calculate.py @@ -0,0 +1,52 @@ +import ast +import operator + +OPERATORS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Pow: operator.pow, + ast.Mod: operator.mod, + ast.USub: operator.neg, +} + + +def _eval(node): + if 
isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return node.value + elif isinstance(node, ast.BinOp) and type(node.op) in OPERATORS: + left = _eval(node.left) + right = _eval(node.right) + if isinstance(node.op, ast.Pow) and isinstance(right, (int, float)) and abs(right) > 10000: + raise ValueError("Exponent too large (max 10000)") + return OPERATORS[type(node.op)](left, right) + elif isinstance(node, ast.UnaryOp) and type(node.op) in OPERATORS: + return OPERATORS[type(node.op)](_eval(node.operand)) + raise ValueError(f"Unsupported expression") + + +tool = { + "type": "function", + "function": { + "name": "calculate", + "description": "Evaluate a math expression. Supports +, -, *, /, **, %.", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string", "description": "The math expression to evaluate (e.g. '2 * (3 + 4)')."}, + }, + "required": ["expression"] + } + } +} + + +def execute(arguments): + expr = arguments.get("expression", "") + try: + tree = ast.parse(expr, mode='eval') + result = _eval(tree.body) + return {"expression": expr, "result": result} + except Exception as e: + return {"error": str(e)} diff --git a/user_data/tools/fetch_webpage.py b/user_data/tools/fetch_webpage.py new file mode 100644 index 00000000..ca3e7331 --- /dev/null +++ b/user_data/tools/fetch_webpage.py @@ -0,0 +1,30 @@ +from modules.web_search import download_web_page, truncate_content_by_tokens + +tool = { + "type": "function", + "function": { + "name": "fetch_webpage", + "description": "Fetch and read the contents of a web page given its URL. 
Returns the page content as plain text.", + "parameters": { + "type": "object", + "properties": { + "url": {"type": "string", "description": "The URL of the web page to fetch."}, + "max_tokens": {"type": "integer", "description": "Maximum number of tokens in the returned content (default: 2048)."}, + }, + "required": ["url"] + } + } +} + + +def execute(arguments): + url = arguments.get("url", "") + max_tokens = arguments.get("max_tokens", 2048) + if not url: + return {"error": "No URL provided."} + + content = download_web_page(url, include_links=True) + if not content or not content.strip(): + return {"error": f"Failed to fetch content from {url}"} + + return {"url": url, "content": truncate_content_by_tokens(content, max_tokens=max_tokens)} diff --git a/user_data/tools/get_datetime.py b/user_data/tools/get_datetime.py new file mode 100644 index 00000000..f0a92777 --- /dev/null +++ b/user_data/tools/get_datetime.py @@ -0,0 +1,18 @@ +from datetime import datetime + +tool = { + "type": "function", + "function": { + "name": "get_datetime", + "description": "Get the current date and time.", + "parameters": { + "type": "object", + "properties": {}, + } + } +} + + +def execute(arguments): + now = datetime.now() + return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")} diff --git a/user_data/tools/roll_dice.py b/user_data/tools/roll_dice.py new file mode 100644 index 00000000..4af38ddc --- /dev/null +++ b/user_data/tools/roll_dice.py @@ -0,0 +1,23 @@ +import random + +tool = { + "type": "function", + "function": { + "name": "roll_dice", + "description": "Roll one or more dice with the specified number of sides.", + "parameters": { + "type": "object", + "properties": { + "count": {"type": "integer", "description": "Number of dice to roll.", "default": 1}, + "sides": {"type": "integer", "description": "Number of sides per die.", "default": 20}, + }, + } + } +} + + +def execute(arguments): + count = max(1, min(arguments.get("count", 1), 1000)) + sides = 
max(2, min(arguments.get("sides", 20), 1000)) + rolls = [random.randint(1, sides) for _ in range(count)] + return {"rolls": rolls, "total": sum(rolls)} diff --git a/user_data/tools/web_search.py b/user_data/tools/web_search.py new file mode 100644 index 00000000..6c2b0f0b --- /dev/null +++ b/user_data/tools/web_search.py @@ -0,0 +1,27 @@ +from modules.web_search import perform_web_search + +tool = { + "type": "function", + "function": { + "name": "web_search", + "description": "Search the web using DuckDuckGo and return a list of result titles and URLs. Use fetch_webpage to read the contents of a specific result.", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "The search query."}, + }, + "required": ["query"] + } + } +} + + +def execute(arguments): + query = arguments.get("query", "") + results = perform_web_search(query, num_pages=None, fetch_content=False) + output = [] + for r in results: + if r: + output.append({"title": r["title"], "url": r["url"]}) + + return output if output else [{"error": "No results found."}]