Merge pull request #7425 from oobabooga/dev

Merge dev branch
oobabooga 2026-03-16 12:51:33 -03:00 committed by GitHub
commit 88a318894c
GPG key ID: B5690EEEBB952194
71 changed files with 3182 additions and 1354 deletions

View file

@ -148,11 +148,11 @@ jobs:
# 6. Create archive
cd ..
if [[ "$RUNNER_OS" == "Windows" ]]; then
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm.zip"
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.zip"
echo "Creating archive: $ARCHIVE_NAME"
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
else
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm.tar.gz"
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.tar.gz"
echo "Creating archive: $ARCHIVE_NAME"
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
fi

157
README.md
View file

@ -13,7 +13,7 @@
# Text Generation Web UI
A Gradio web UI for running Large Language Models locally. 100% private, offline, and free.
A Gradio web UI for running Large Language Models locally. 100% private and offline. Supports text generation, vision, tool-calling, training, image generation, and more.
[Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason)
@ -23,22 +23,21 @@ A Gradio web UI for running Large Language Models locally. 100% private, offline
## Features
- Supports multiple local text generation backends, including [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) (the latter via its own [Dockerfile](https://github.com/oobabooga/text-generation-webui/blob/main/docker/TensorRT-LLM/Dockerfile)).
- Easy setup: Choose between **portable builds** (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or the one-click installer that creates a self-contained `installer_files` directory.
- **Multiple backends**: [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Switch between backends and models without restarting.
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
- **Vision (multimodal)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
- **Tool-calling**: Models can call custom functions during chat — web search, page fetching, math, and more. Each tool is a single `.py` file, easy to create and extend ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Tool-Calling-Tutorial)).
- **OpenAI-compatible API**: Chat and Completions endpoints with tool-calling support. Use as a local drop-in replacement for the OpenAI API ([examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples)).
- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Easy setup**: [Portable builds](https://github.com/oobabooga/text-generation-webui/releases) (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or a one-click installer for the full feature set.
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
- `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters. Prompts are automatically formatted with Jinja2 templates.
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Easy to use, good defaults, and supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)).
- Edit messages, navigate between message versions, and branch conversations at any point.
- Switch between different models in the UI without restarting.
- Free-form text generation in the Notebook tab without being limited to chat turns.
- Multiple sampling parameters and generation options for sophisticated text generation control.
- Aesthetic UI with dark and light themes.
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
- OpenAI-compatible API with Chat and Completions endpoints, including tool-calling support; see [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples).
- Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.
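As a rough illustration of the OpenAI-compatible API listed above, the sketch below builds a Chat Completions request using only the Python standard library. The address `127.0.0.1:5000`, the `/v1/chat/completions` path, and the `Bearer` authorization scheme are assumptions based on the usual OpenAI conventions; adjust them to match how the server was started (`--api`, `--api-port`, `--api-key`). The helper names `build_chat_request` and `send` are hypothetical.

```python
import json
import urllib.request

# Assumed address; change it to match your --api-port / --listen-host settings.
BASE_URL = "http://127.0.0.1:5000/v1"


def build_chat_request(prompt, api_key=None):
    """Build (but do not send) a POST request for the chat completions endpoint."""
    payload = {"messages": [{"role": "user", "content": prompt}]}
    headers = {"Content-Type": "application/json"}
    if api_key:  # assumed to be needed only when the server uses --api-key
        headers["Authorization"] = f"Bearer {api_key}"
    return urllib.request.Request(
        f"{BASE_URL}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )


def send(request):
    """Send the request and return the assistant's reply text."""
    with urllib.request.urlopen(request) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]
```

With the server running with `--api`, `send(build_chat_request("Hello"))` should return the model's reply as a string. Keeping request construction separate from sending makes the payload easy to inspect without a live server.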
## How to install
@ -47,10 +46,11 @@ A Gradio web UI for running Large Language Models locally. 100% private, offline
No installation needed: just download, unzip, and run. All dependencies included.
Compatible with GGUF (llama.cpp) models on Windows, Linux, and macOS. [Check what models fit your hardware](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator).
Download from here: **https://github.com/oobabooga/text-generation-webui/releases**
- Builds are provided for Linux, Windows, and macOS, with options for CUDA, Vulkan, ROCm, and CPU-only.
- Compatible with GGUF (llama.cpp) models.
#### Option 2: Manual portable install with venv
Very fast setup that should work on any Python 3.9+:
@ -81,7 +81,7 @@ deactivate
#### Option 3: One-click installer
For users who need additional backends (ExLlamaV3, Transformers) or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch.
For users who need additional backends (ExLlamaV3, Transformers), training, image generation, or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch.
1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it.
2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`.
@ -146,7 +146,7 @@ conda activate textgen
|--------|---------|---------|
| Linux/WSL | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
| Linux/WSL | CPU only | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu` |
| Linux | AMD | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/rocm6.4` |
| Linux | AMD | `pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp313-cp313-linux_x86_64.whl` |
| MacOS + MPS | Any | `pip3 install torch==2.9.1` |
| Windows | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
| Windows | CPU only | `pip3 install torch==2.9.1` |
@ -201,7 +201,7 @@ ln -s docker/{nvidia/Dockerfile,nvidia/docker-compose.yml,.dockerignore} .
For AMD GPU:
ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} .
For Intel GPU:
ln -s docker/{intel/Dockerfile,amd/docker-compose.yml,.dockerignore} .
ln -s docker/{intel/Dockerfile,intel/docker-compose.yml,.dockerignore} .
For CPU only:
ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} .
cp docker/.env.example .env
@ -236,20 +236,24 @@ List of command-line flags
</summary>
```txt
usage: server.py [-h] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS]
usage: server.py [-h] [--user-data-dir USER_DATA_DIR] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS]
[--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--image-model IMAGE_MODEL] [--image-model-dir IMAGE_MODEL_DIR] [--image-dtype {bfloat16,float16}]
[--image-attn-backend {flash_attention_2,sdpa}] [--image-cpu-offload] [--image-compile] [--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}]
[--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT] [--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT]
[--ctx-size-draft CTX_SIZE_DRAFT] [--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}] [--spec-ngram-size-n SPEC_NGRAM_SIZE_N]
[--spec-ngram-size-m SPEC_NGRAM_SIZE_M] [--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS] [--gpu-layers N] [--cpu-moe] [--mmproj MMPROJ] [--streaming-llm] [--tensor-split TENSOR_SPLIT]
[--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--ubatch-size UBATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa]
[--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16] [--no-cache] [--trust-remote-code]
[--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE] [--quant_type QUANT_TYPE]
[--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--cpp-runner]
[--alpha_value ALPHA_VALUE] [--rope_freq_base ROPE_FREQ_BASE] [--compress_pos_emb COMPRESS_POS_EMB] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share]
[--parallel PARALLEL] [--fit-target FIT_TARGET] [--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16]
[--no-cache] [--trust-remote-code] [--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE]
[--quant_type QUANT_TYPE] [--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share]
[--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors]
[--portable] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4]
[--nowebui]
[--nowebui] [--temperature N] [--dynatemp-low N] [--dynatemp-high N] [--dynatemp-exponent N] [--smoothing-factor N] [--smoothing-curve N] [--min-p N] [--top-p N] [--top-k N]
[--typical-p N] [--xtc-threshold N] [--xtc-probability N] [--epsilon-cutoff N] [--eta-cutoff N] [--tfs N] [--top-a N] [--top-n-sigma N] [--adaptive-target N] [--adaptive-decay N]
[--dry-multiplier N] [--dry-allowed-length N] [--dry-base N] [--repetition-penalty N] [--frequency-penalty N] [--presence-penalty N] [--encoder-repetition-penalty N]
[--no-repeat-ngram-size N] [--repetition-penalty-range N] [--penalty-alpha N] [--guidance-scale N] [--mirostat-mode N] [--mirostat-tau N] [--mirostat-eta N]
[--do-sample | --no-do-sample] [--dynamic-temperature | --no-dynamic-temperature] [--temperature-last | --no-temperature-last] [--sampler-priority N] [--dry-sequence-breakers N]
[--enable-thinking | --no-enable-thinking] [--reasoning-effort N] [--chat-template-file CHAT_TEMPLATE_FILE]
Text Generation Web UI
@ -257,7 +261,8 @@ options:
-h, --help show this help message and exit
Basic settings:
--multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.
--user-data-dir USER_DATA_DIR Path to the user data directory. Default: auto-detected.
--multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.
--model MODEL Name of the model to load by default.
--lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.
--model-dir MODEL_DIR Path to directory with all the models.
@ -280,12 +285,12 @@ Image model:
Quantization method for image model.
Model loader:
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3,
TensorRT-LLM.
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT-
LLM.
Context and cache:
--ctx-size N, --n_ctx N, --max_seq_len N Context size in tokens. llama.cpp: 0 = auto if gpu-layers is also -1.
--cache-type N, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).
--ctx-size, --n_ctx, --max_seq_len N Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.
--cache-type, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).
Speculative decoding:
--model-draft MODEL_DRAFT Path to the draft model for speculative decoding.
@ -300,7 +305,7 @@ Speculative decoding:
--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS Minimum n-gram hits for ngram-map speculative decoding.
llama.cpp:
--gpu-layers N, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto.
--gpu-layers, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto.
--cpu-moe Move the experts to the CPU (for MoE models).
--mmproj MMPROJ Path to the mmproj file for vision models.
--streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.
@ -314,13 +319,17 @@ llama.cpp:
--threads THREADS Number of threads to use.
--threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing.
--numa Activate NUMA task allocation for llama.cpp.
--parallel PARALLEL Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set
ctx_size to 32768.
--fit-target FIT_TARGET Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.
Default: 1024.
--extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"
Transformers/Accelerate:
--cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow.
--cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading.
--disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.
--disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to. Defaults to "user_data/cache".
--disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to.
--load-in-8bit Load the model with 8-bit precision (using bitsandbytes).
--bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
--no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.
@ -341,14 +350,6 @@ ExLlamaV3:
--tp-backend TP_BACKEND The backend for tensor parallelism. Valid options: native, nccl. Default: native.
--cfg-cache Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.
TensorRT-LLM:
--cpp-runner Use the ModelRunnerCpp runner, which is faster than the default ModelRunner.
RoPE:
--alpha_value ALPHA_VALUE Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.
--rope_freq_base ROPE_FREQ_BASE If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).
--compress_pos_emb COMPRESS_POS_EMB Positional embeddings compression factor. Should be set to (context length) / (model's original context length). Equal to 1/rope_freq_scale.
Gradio:
--listen Make the web UI reachable from your local network.
--listen-port LISTEN_PORT The listening port that the server will use.
@ -365,7 +366,7 @@ Gradio:
API:
--api Enable the API extension.
--public-api Create a public URL for the API using Cloudfare.
--public-api Create a public URL for the API using Cloudflare.
--public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.
--api-port API_PORT The listening port for the API.
--api-key API_KEY API authentication key.
@ -373,28 +374,67 @@ API:
--api-enable-ipv6 Enable IPv6 for the API
--api-disable-ipv4 Disable IPv4 for the API
--nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode.
API generation defaults:
--temperature N Temperature
--dynatemp-low N Dynamic temperature low
--dynatemp-high N Dynamic temperature high
--dynatemp-exponent N Dynamic temperature exponent
--smoothing-factor N Smoothing factor
--smoothing-curve N Smoothing curve
--min-p N Min P
--top-p N Top P
--top-k N Top K
--typical-p N Typical P
--xtc-threshold N XTC threshold
--xtc-probability N XTC probability
--epsilon-cutoff N Epsilon cutoff
--eta-cutoff N Eta cutoff
--tfs N TFS
--top-a N Top A
--top-n-sigma N Top N Sigma
--adaptive-target N Adaptive target
--adaptive-decay N Adaptive decay
--dry-multiplier N DRY multiplier
--dry-allowed-length N DRY allowed length
--dry-base N DRY base
--repetition-penalty N Repetition penalty
--frequency-penalty N Frequency penalty
--presence-penalty N Presence penalty
--encoder-repetition-penalty N Encoder repetition penalty
--no-repeat-ngram-size N No repeat ngram size
--repetition-penalty-range N Repetition penalty range
--penalty-alpha N Penalty alpha
--guidance-scale N Guidance scale
--mirostat-mode N Mirostat mode
--mirostat-tau N Mirostat tau
--mirostat-eta N Mirostat eta
--do-sample, --no-do-sample Do sample
--dynamic-temperature, --no-dynamic-temperature Dynamic temperature
--temperature-last, --no-temperature-last Temperature last
--sampler-priority N Sampler priority
--dry-sequence-breakers N DRY sequence breakers
--enable-thinking, --no-enable-thinking Enable thinking
--reasoning-effort N Reasoning effort
--chat-template-file CHAT_TEMPLATE_FILE Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model's
built-in template.
```
</details>
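The `--parallel` help text says the context size is divided equally among request slots. A minimal sketch of that arithmetic follows; the function names are hypothetical and only mirror the example given in the flag description (4 slots with 8192 tokens each).

```python
def total_ctx_size(slots, ctx_per_slot):
    """Context size to request so that each slot gets ctx_per_slot tokens."""
    return slots * ctx_per_slot


def ctx_per_slot(total_ctx, slots):
    """Tokens available to each slot when total_ctx is split equally."""
    return total_ctx // slots


# Example from the help text: 4 slots with 8192 tokens of context each.
print(total_ctx_size(4, 8192))
print(ctx_per_slot(32768, 4))
```

Under these assumptions, the matching launch flags would be `--parallel 4 --ctx-size 32768`.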
## Downloading models
Models should be placed in the folder `text-generation-webui/user_data/models`. They are usually downloaded from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf).
1. Download a GGUF model file from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf).
2. Place it in the `user_data/models` folder.
To check if a GGUF model will fit in your hardware before downloading it, you can use this tool I created:
That's it. The UI will detect it automatically.
[Accurate GGUF VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator)
To check what will fit your GPU, you can use the [VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator).
* GGUF models are a single file and should be placed directly into `user_data/models`. Example:
<details>
<summary>Other model types (Transformers, EXL3)</summary>
```
text-generation-webui
└── user_data
    └── models
        └── llama-2-13b-chat.Q4_K_M.gguf
```
* The remaining model types (like 16-bit Transformers models and EXL3 models) are made of several files and must be placed in a subfolder. Example:
Models that consist of multiple files (like 16-bit Transformers models and EXL3 models) should be placed in a subfolder inside `user_data/models`:
```
text-generation-webui
@ -404,31 +444,18 @@ text-generation-webui
├── config.json
├── generation_config.json
├── model-00001-of-00004.safetensors
├── model-00002-of-00004.safetensors
├── model-00003-of-00004.safetensors
├── model-00004-of-00004.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── ...
├── tokenizer_config.json
└── tokenizer.json
```
In both cases, you can use the "Model" tab of the UI to download the model from Hugging Face automatically. It is also possible to download it via the command-line with:
```
python download-model.py organization/model
```
Run `python download-model.py --help` to see all the options.
These formats require the one-click installer (not the portable build).
</details>
## Documentation
https://github.com/oobabooga/text-generation-webui/wiki
## Google Colab notebook
https://colab.research.google.com/github/oobabooga/text-generation-webui/blob/main/Colab-TextGen-GPU.ipynb
## Community
https://www.reddit.com/r/Oobabooga/

View file

@ -21,6 +21,7 @@ set INSTALL_ENV_DIR=%cd%\installer_files\env
set PYTHONNOUSERSITE=1
set PYTHONPATH=
set PYTHONHOME=
set PYTHONUTF8=1
set "CUDA_PATH=%INSTALL_ENV_DIR%"
set "CUDA_HOME=%CUDA_PATH%"

View file

@ -2,6 +2,7 @@
display: grid;
align-items: start;
grid-template-columns: 60px minmax(0, 1fr);
width: min(100%, calc(724px + 60px));
padding-bottom: 22px;
padding-top: 6px;
font-size: 18px;
@ -91,9 +92,6 @@
}
.message-body p {
margin-bottom: 0 !important;
font-size: 16px !important;
line-height: 1.5 !important;
color: #e0e0e0 !important; /* Light color for text */
}
@ -122,7 +120,7 @@
}
.message-body p {
font-size: 14px !important; /* Smaller text for mobile */
font-size: 14px !important;
}
.username {

View file

@ -4,6 +4,7 @@
display: grid;
align-items: start;
grid-template-columns: 60px minmax(0, 1fr);
width: min(100%, calc(724px + 60px + 90px));
padding-bottom: 21px;
padding-top: 7px;
font-size: 18px;
@ -86,10 +87,8 @@
border-radius: 20px;
}
.message-body p {
margin-bottom: 0 !important;
.message-body p, .message-body li {
font-size: 18px !important;
line-height: 1.428571429 !important;
color: rgb(243 244 246) !important;
text-shadow: 2px 2px 2px rgb(0 0 0);
font-weight: 500;
@ -127,7 +126,7 @@
padding-left: 0;
}
.message-body p {
.message-body p, .message-body li {
font-size: 16px !important;
}

View file

@ -19,4 +19,5 @@
padding-bottom: 1.5em;
padding-top: 0.5em;
grid-template-columns: 70px minmax(0, 1fr);
width: min(100%, calc(724px + 70px));
}

View file

@ -2,6 +2,7 @@
display: grid;
align-items: start;
grid-template-columns: 60px minmax(0, 1fr);
width: min(100%, calc(724px + 60px));
padding-bottom: 1.5em;
padding-top: 0.5em;
font-size: 15px;
@ -46,16 +47,10 @@
border-radius: 20px;
}
.message-body p {
font-size: 15px !important;
line-height: 22.5px !important;
.message-body p, .message-body li {
font-weight: 500;
}
.message-body p, .chat .message-body ul, .chat .message-body ol {
margin-bottom: 10px !important;
}
.dark .message-body p em {
color: rgb(138 138 138) !important;
}

View file

@ -1,4 +1,5 @@
.message {
width: min(100%, calc(724px + 60px));
padding-bottom: 22px;
padding-top: 3px;
font-size: 15px;
@ -60,8 +61,10 @@
text-align: right;
}
.dark .circle-bot + .text div, .dark .circle-bot + .text * {
color: #000;
.dark .circle-bot + .text div, .dark .circle-bot + .text *,
.dark .chat .message .circle-bot + .text .message-body :is(h1, h2, h3, h4, h5, h6),
.dark .chat .message .circle-bot + .text .message-body a {
color: #000 !important;
}
.text {
@ -76,19 +79,14 @@
font-weight: bold;
}
.message-body {
}
.message-body img {
max-width: 300px;
max-height: 300px;
border-radius: 20px;
}
.message-body p {
margin-bottom: 0 !important;
.message-body p, .message-body li {
font-size: 15px !important;
line-height: 1.428571429 !important;
font-weight: 500;
}

View file

@ -1,5 +1,6 @@
.message {
display: block;
width: min(100%, 724px);
padding-top: 0;
padding-bottom: 21px;
font-size: 15px;
@ -77,14 +78,8 @@
border-radius: 12px;
}
.message-body p {
.message-body p, .message-body li {
font-size: 15px !important;
line-height: 1.4 !important;
font-weight: 400;
}
.message-body p:first-child {
margin-top: 0 !important;
}
.dark .message-body p em {
@ -100,6 +95,3 @@
margin-top: 8px;
}
.message-body p, .chat .message-body ul, .chat .message-body ol {
margin-bottom: 10px !important;
}

View file

@ -78,7 +78,7 @@
.chat .user-message .text,
.chat .assistant-message .text {
max-width: 700px;
max-width: 724px;
margin-left: auto;
margin-right: auto;
}

View file

@ -400,7 +400,6 @@ audio {
}
.chat .message {
width: min(100%, 48rem);
margin-left: auto;
margin-right: auto;
text-align: start;
@ -431,10 +430,19 @@ audio {
font-size: 16px;
}
.dark .message-body :is(h1, h2, h3, h4, h5, h6) {
.dark .message-body h1,
.dark .message-body h2,
.dark .message-body h3,
.dark .message-body h4,
.dark .message-body h5,
.dark .message-body h6 {
color: white !important;
}
.dark .message-body blockquote {
border-left-color: rgb(255 255 255 / 30%);
}
.message-body h1 {
font-weight: 800;
font-size: 2.25em;
@ -831,9 +839,20 @@ audio {
}
}
.message-body ol, .message-body ul {
.message-body p, .message-body li {
line-height: 1.75 !important;
}
.message-body p, .message-body ul, .message-body ol {
margin: 1.25em 0 !important;
}
.message-body :is(p, ul, ol):first-child {
margin-top: 0 !important;
margin-bottom: 1.25em !important;
}
.message-body :is(p, ul, ol):last-child {
margin-bottom: 0 !important;
}
/* ----------------------------------------------
@ -1003,6 +1022,49 @@ audio {
padding-right: 0.5rem;
}
#new-chat-wrapper {
display: contents;
}
.new-chat-arrow {
cursor: pointer;
position: relative;
padding: 0;
margin-right: -15px;
height: 39.594px;
display: flex;
align-items: center;
}
.new-chat-menu {
display: none;
position: absolute;
top: 0;
left: 0;
padding-top: 1.2em;
z-index: var(--layer-top);
white-space: nowrap;
}
.new-chat-arrow:hover .new-chat-menu {
display: block;
}
.new-chat-menu-item {
cursor: pointer;
padding: var(--size-2);
background: var(--background-fill-primary);
box-shadow: var(--shadow-drop-lg);
border-radius: var(--container-radius);
color: var(--body-text-color);
font-size: var(--text-md);
font-weight: var(--button-large-text-weight);
}
.new-chat-menu-item:hover {
background: var(--background-fill-secondary);
}
#past-chats-row,
#chat-controls {
width: 260px;
@ -1373,7 +1435,6 @@ audio {
overflow-wrap: break-word;
max-height: 250px;
overflow-y: scroll;
contain: layout;
}
.chat .message-body .thinking-content p,
@ -1662,7 +1723,7 @@ button:focus {
.chat-parent {
/* Optimize for scrolling performance */
will-change: scroll-position;
contain: layout style paint;
contain: style paint;
/* Ensure GPU acceleration */
transform: translateZ(0);
@ -1802,6 +1863,15 @@ table {
border-collapse: collapse;
}
.table-wrapper {
overflow-x: auto;
}
.message-body :is(td, th) {
word-break: normal;
overflow-wrap: normal;
}
table, tr, td, th, thead {
border: 0;
}
@ -1814,3 +1884,86 @@ tr + tr th { border-top: 1px solid; }
thead + tbody tr:first-child td,
thead + tbody tr:first-child th { border-top: 1px solid; }
/* ------------------------------------------------
Tools CheckboxGroup - vertical DragDrop-like style
------------------------------------------------ */
/* "Refresh list" link in the Tools label */
.tools-refresh-link {
cursor: pointer;
}
/* Checkbox list container */
#tools-group {
padding: 0 !important;
border-width: 0 !important;
background: transparent !important;
min-height: 0 !important;
}
#tools-group .wrap {
display: flex;
flex-direction: column;
flex-wrap: nowrap;
gap: 4px;
padding: 0;
margin-top: var(--spacing-lg);
max-height: 350px;
overflow-y: auto;
}
/* Pretty scrollbar for the tools list */
#tools-group .wrap::-webkit-scrollbar {
width: 8px;
height: 8px;
}
#tools-group .wrap::-webkit-scrollbar-track {
background: transparent;
}
#tools-group .wrap::-webkit-scrollbar-thumb,
#tools-group .wrap::-webkit-scrollbar-thumb:hover {
background: var(--neutral-300);
border-radius: 30px;
}
.dark #tools-group .wrap::-webkit-scrollbar-thumb,
.dark #tools-group .wrap::-webkit-scrollbar-thumb:hover {
background: rgb(255 255 255 / 6.25%);
border-radius: 10px;
}
#tools-group .wrap::-webkit-scrollbar-corner {
background: transparent;
}
/* Each checkbox item */
#tools-group label {
display: flex;
align-items: center;
gap: 8px;
padding: 5px 8px;
border-radius: var(--radius-sm, 4px);
background: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
color: var(--body-text-color);
font-size: var(--input-text-size);
font-weight: var(--input-text-weight);
cursor: pointer;
user-select: none;
transition: border-color 0.15s ease, background 0.15s ease;
box-shadow: none;
}
#tools-group label:hover {
border-color: var(--input-border-color-focus);
}
#tools-group label span {
flex: 1;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}

View file

@ -41,9 +41,6 @@ Options:
* **cpu_memory**: Maximum CPU memory in GiB to use for CPU offloading via the accelerate library. Whatever doesn't fit in the GPU or CPU will go to a disk cache if the "disk" checkbox is enabled.
* **compute_dtype**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
* **quant_type**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
* **alpha_value**: Used to extend the context length of a model with a minor loss in quality. I have measured 1.75 to be optimal for 1.5x context, and 2.5 for 2x context. That is, with alpha = 2.5 you can make a model with 4096 context length go to 8192 context length.
* **rope_freq_base**: Originally another way to write "alpha_value", it ended up becoming a necessary parameter for some models like CodeLlama, which was fine-tuned with this set to 1000000 and hence needs to be loaded with it set to 1000000 as well.
* **compress_pos_emb**: The first and original context-length extension method, discovered by [kaiokendev](https://kaiokendev.github.io/til). When set to 2, the context length is doubled, 3 and it's tripled, etc. It should only be used for models that have been fine-tuned with this parameter set to different than 1. For models that have not been tuned to have greater context length, alpha_value will lead to a smaller accuracy loss.
* **attn_implementation**: Choose the attention implementation. Valid options: `sdpa`, `eager`, `flash_attention_2`. The default (`sdpa`) works well in most cases; `flash_attention_2` may be useful for training.
* **cpu**: Loads the model in CPU mode using PyTorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with Transformers predates llama.cpp and still works, but it is much slower. Note: this parameter has a different interpretation in the llama.cpp loader (see above).
* **load_in_8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training and not inference, so load_in_8bit is slower than load_in_4bit (but more accurate).

View file

@ -4,7 +4,7 @@ A LoRA is tied to a specific model architecture — a LoRA trained on Llama 3 8B
### Quick Start
1. Load your base model (no LoRAs loaded).
1. Load your base model with the **Transformers** loader (no LoRAs loaded).
2. Open the **Training** tab > **Train LoRA**.
3. Pick a dataset and configure parameters (see [below](#parameters)).
4. Click **Start LoRA Training** and monitor the [loss](#loss).

View file

@ -0,0 +1,159 @@
## Supported models
The following models are supported:
- Qwen 3.5
- GPT-OSS
- Mistral Small / Devstral
- DeepSeek V3
- Kimi-K2
- MiniMax-M2.5
- GLM-5
- Llama 4
Other models that output tool calls as JSON (inside XML tags, code blocks, or plain JSON) are also supported through a generic fallback parser.
## Tool calling in the UI
### 1. Load a model with tool-calling support
Load a model with tool-calling support from the Model tab.
### 2. Select tools
In the chat sidebar, check the tools you want the model to use:
- **web_search** -- Search the web using DuckDuckGo.
- **fetch_webpage** -- Fetch the content of a URL.
- **calculate** -- Evaluate math expressions.
- **get_datetime** -- Get the current date and time.
- **roll_dice** -- Roll dice.
### 3. Chat
Send a message as usual. When the model decides it needs a tool, it will call it automatically. You will see each tool call and its result in a collapsible accordion inside the chat message.
The model may call multiple tools in sequence before giving its final answer.
## Writing custom tools
Each tool is a single `.py` file in `user_data/tools/`. It needs two things:
1. A `tool` dictionary that describes the function (name, description, parameters).
2. An `execute(arguments)` function that runs it and returns the result.
Here is a minimal example (`user_data/tools/get_datetime.py`):
```python
from datetime import datetime

tool = {
    "type": "function",
    "function": {
        "name": "get_datetime",
        "description": "Get the current date and time.",
        "parameters": {
            "type": "object",
            "properties": {},
        }
    }
}


def execute(arguments):
    now = datetime.now()
    return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")}
```
An example with parameters (`user_data/tools/roll_dice.py`):
```python
import random

tool = {
    "type": "function",
    "function": {
        "name": "roll_dice",
        "description": "Roll one or more dice with the specified number of sides.",
        "parameters": {
            "type": "object",
            "properties": {
                "count": {"type": "integer", "description": "Number of dice to roll.", "default": 1},
                "sides": {"type": "integer", "description": "Number of sides per die.", "default": 20},
            },
        }
    }
}


def execute(arguments):
    count = max(1, min(arguments.get("count", 1), 1000))
    sides = max(2, min(arguments.get("sides", 20), 1000))
    rolls = [random.randint(1, sides) for _ in range(count)]
    return {"rolls": rolls, "total": sum(rolls)}
```
You can open the built-in tools in `user_data/tools/` for more examples.
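As a further illustration of the two-part layout described above, here is a minimal sketch of a tool that takes a required parameter and validates it before use. The file name `user_data/tools/word_count.py`, the tool name `word_count`, and its `text` parameter are hypothetical and are not among the built-in tools. Returning an `error` key for bad input is only one possible convention.

```python
# Hypothetical file: user_data/tools/word_count.py (not a built-in tool)

tool = {
    "type": "function",
    "function": {
        "name": "word_count",
        "description": "Count the words and characters in a piece of text.",
        "parameters": {
            "type": "object",
            "properties": {
                "text": {"type": "string", "description": "Text to analyze."},
            },
            "required": ["text"]
        }
    }
}


def execute(arguments):
    # The model may omit or mistype arguments, so validate before using them.
    text = arguments.get("text")
    if not isinstance(text, str):
        return {"error": "The 'text' argument must be a string."}

    return {"words": len(text.split()), "characters": len(text)}
```

Calling `execute({"text": "hello big world"})` directly in a Python shell is a quick way to check the return value before enabling the tool in the chat sidebar.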
## Tool calling over the API
Tool calling over the API follows the [OpenAI API](https://platform.openai.com/docs/guides/function-calling) convention. Define your tools, send them with your messages, and handle tool calls in a loop until the model gives a final answer.
```python
import json
import requests

url = "http://127.0.0.1:5000/v1/chat/completions"

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a given location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"},
                },
                "required": ["location"]
            }
        }
    }
]


def execute_tool(name, arguments):
    if name == "get_weather":
        return {"temperature": "14°C", "condition": "partly cloudy"}
    return {"error": f"Unknown tool: {name}"}


messages = [{"role": "user", "content": "What's the weather like in Paris?"}]

for _ in range(10):
    response = requests.post(url, json={"messages": messages, "tools": tools}).json()
    choice = response["choices"][0]

    if choice["finish_reason"] == "tool_calls":
        messages.append({
            "role": "assistant",
            "content": choice["message"]["content"],
            "tool_calls": choice["message"]["tool_calls"],
        })

        for tool_call in choice["message"]["tool_calls"]:
            name = tool_call["function"]["name"]
            arguments = json.loads(tool_call["function"]["arguments"])
            result = execute_tool(name, arguments)
            print(f"Tool call: {name}({arguments}) => {result}")
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": json.dumps(result),
            })
    else:
        print(f"\nAssistant: {choice['message']['content']}")
        break
```
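The request schema touched by this change also accepts a `tool_choice` field. In the accompanying code, `"none"` disables tool detection for that request, and an object naming a function restricts detection to that function. The sketch below shows how a client might build such request bodies. `build_request` and `force_function` are hypothetical helper names, and the code only constructs the JSON payload without contacting the server.

```python
def build_request(messages, tools, tool_choice=None):
    """Build a chat completions request body (hypothetical helper)."""
    body = {"messages": messages, "tools": tools}
    # Only include tool_choice when the caller asked for a specific behavior.
    if tool_choice is not None:
        body["tool_choice"] = tool_choice
    return body


def force_function(name):
    """Return the object form of tool_choice that names a single function."""
    return {"type": "function", "function": {"name": name}}


weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "City name"},
            },
            "required": ["location"]
        }
    }
}

messages = [{"role": "user", "content": "What's the weather like in Paris?"}]

# Tools are defined, but tool detection is turned off for this request.
no_tools = build_request(messages, [weather_tool], tool_choice="none")

# Only calls to get_weather are considered for this request.
only_weather = build_request(messages, [weather_tool], tool_choice=force_function("get_weather"))
```

Either body can be passed as the `json=` argument of `requests.post`, in place of the inline dictionary used in the loop above.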

View file

@ -11,8 +11,10 @@ from pydantic import ValidationError
from extensions.openai.errors import InvalidRequestError
from extensions.openai.typing import ToolDefinition
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
from extensions.openai.utils import debug_msg
from modules.tool_parsing import get_tool_call_id, parse_tool_call, detect_tool_call_format
from modules import shared
from modules.reasoning import extract_reasoning
from modules.chat import (
generate_chat_prompt,
generate_chat_reply,
@ -37,17 +39,114 @@ def load_chat_template_file(filepath):
return text
def convert_logprobs_to_tiktoken(model, logprobs):
# more problems than it's worth.
# try:
# encoder = tiktoken.encoding_for_model(model)
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
# except KeyError:
# # assume native tokens if we can't find the tokenizer
# return logprobs
def _get_raw_logprob_entries(offset=0):
"""Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.
return logprobs
Returns (new_entries, new_offset).
"""
if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
return [], offset
all_entries = shared.model.last_completion_probabilities
new_entries = all_entries[offset:]
return new_entries, len(all_entries)
def _dict_to_logprob_entries(token_dict):
"""Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format."""
if not token_dict:
return []
return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}]
def _parse_entry_top(entry):
"""Extract the top logprobs list from a raw entry, handling both key names."""
return entry.get('top_logprobs', entry.get('top_probs', []))
def format_chat_logprobs(entries):
"""Format logprob entries into OpenAI chat completions logprobs format.
Output: {"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}]}
"""
if not entries:
return None
content = []
for entry in entries:
top = _parse_entry_top(entry)
if not top:
continue
chosen = top[0]
token_str = chosen.get('token', '')
token_logprob = chosen.get('logprob', chosen.get('prob', 0))
top_list = []
for item in top:
t = item.get('token', '')
lp = item.get('logprob', item.get('prob', 0))
top_list.append({
"token": t,
"logprob": lp,
"bytes": list(t.encode('utf-8')) if t else None
})
content.append({
"token": token_str,
"logprob": token_logprob,
"bytes": list(token_str.encode('utf-8')) if token_str else None,
"top_logprobs": top_list
})
return {"content": content, "refusal": None} if content else None
def format_completion_logprobs(entries):
"""Format logprob entries into OpenAI completions logprobs format.
Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"}
"""
if not entries:
return None
tokens = []
token_logprobs = []
top_logprobs = []
text_offset = []
offset = 0
for entry in entries:
top = _parse_entry_top(entry)
if not top:
continue
chosen = top[0]
token_str = chosen.get('token', '')
token_logprob = chosen.get('logprob', chosen.get('prob', 0))
tokens.append(token_str)
token_logprobs.append(token_logprob)
text_offset.append(offset)
offset += len(token_str)
top_dict = {}
for item in top:
t = item.get('token', '')
lp = item.get('logprob', item.get('prob', 0))
top_dict[t] = lp
top_logprobs.append(top_dict)
if not tokens:
return None
return {
"tokens": tokens,
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
"text_offset": text_offset
}
def process_parameters(body, is_legacy=False):
@ -72,7 +171,16 @@ def process_parameters(body, is_legacy=False):
elif isinstance(body['stop'], list):
generate_params['custom_stopping_strings'] = body['stop']
if shared.args.loader != 'llama.cpp':
# Resolve logprobs: for chat completions, logprobs is a bool and the count
# comes from top_logprobs. Normalize to an int for all backends.
logprobs = body.get('logprobs', None)
top_logprobs = body.get('top_logprobs', None)
if logprobs is True:
logprobs = max(top_logprobs, 1) if top_logprobs is not None else 5
generate_params['logprobs'] = logprobs
# For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
from transformers import LogitsProcessorList
from modules.transformers_loader import (
@ -85,13 +193,9 @@ def process_parameters(body, is_legacy=False):
if logit_bias: # {str: float, ...}
logits_processor = [LogitsBiasProcessor(logit_bias)]
logprobs = None # coming to chat eventually
if 'logprobs' in body:
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
if logprobs is not None and logprobs > 0:
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
logits_processor.extend([generate_params['logprob_proc']])
else:
logprobs = None
if logits_processor: # requires logits_processor support
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
@ -137,12 +241,14 @@ def convert_history(history):
user_input = ""
user_input_last = True
system_message = ""
seen_non_system = False
for entry in history:
content = entry["content"]
role = entry["role"]
if role == "user":
seen_non_system = True
# Extract text content (images handled by model-specific code)
content = process_multimodal_content(content)
user_input = content
@ -154,6 +260,7 @@ def convert_history(history):
current_message = content
elif role == "assistant":
seen_non_system = True
meta = {}
tool_calls = entry.get("tool_calls")
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
@ -170,13 +277,22 @@ def convert_history(history):
else:
chat_dialogue.append(['', current_reply, '', meta])
elif role == "tool":
seen_non_system = True
user_input_last = False
meta = {}
if "tool_call_id" in entry:
meta["tool_call_id"] = entry["tool_call_id"]
chat_dialogue.append(['', '', content, meta])
elif role == "system":
system_message += f"\n{content}" if system_message else content
elif role in ("system", "developer"):
if not seen_non_system:
# Leading system messages go to custom_system_message (placed at top)
system_message += f"\n{content}" if system_message else content
else:
# Mid-conversation system messages: preserve position in history
if current_message:
chat_dialogue.append([current_message, '', '', {}])
current_message = ""
chat_dialogue.append([content, '', '', {"role": "system"}])
if not user_input_last:
user_input = ""
@ -202,6 +318,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
tool_choice = body.get('tool_choice', None)
if tool_choice == "none":
tools = None # Disable tool detection entirely
messages = body['messages']
for m in messages:
if 'role' not in m:
@ -293,29 +413,55 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
requested_model = generate_params.pop('model')
logprob_proc = generate_params.pop('logprob_proc', None)
if logprob_proc:
logprob_proc.token_alternatives_history.clear()
chat_logprobs_offset = [0] # mutable for closure access in streaming
def chat_streaming_chunk(content, chunk_tool_calls=None):
def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False, reasoning_content=None):
# begin streaming
delta = {}
if include_role:
delta['role'] = 'assistant'
delta['refusal'] = None
if content is not None:
delta['content'] = content
if reasoning_content is not None:
delta['reasoning_content'] = reasoning_content
if chunk_tool_calls:
delta['tool_calls'] = chunk_tool_calls
chunk = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
"system_fingerprint": None,
resp_list: [{
"index": 0,
"finish_reason": None,
"delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls},
"delta": delta,
"logprobs": None,
}],
}
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# chunk[resp_list][0]["logprobs"] = None
if logprob_proc:
entries = _dict_to_logprob_entries(logprob_proc.token_alternatives)
formatted = format_chat_logprobs(entries)
if formatted:
chunk[resp_list][0]["logprobs"] = formatted
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0])
if entries:
formatted = format_chat_logprobs(entries)
if formatted:
chunk[resp_list][0]["logprobs"] = formatted
return chunk
# Check if usage should be included in streaming chunks per OpenAI spec
stream_options = body.get('stream_options')
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
# generate reply #######################################
if prompt_only:
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
@ -323,75 +469,133 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
return
if stream:
yield chat_streaming_chunk('')
chunk = chat_streaming_chunk('', include_role=True)
if include_usage:
chunk['usage'] = None
yield chunk
generator = generate_chat_reply(
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
answer = ''
seen_content = ''
seen_reasoning = ''
tool_calls = []
end_last_tool_call = 0
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
_tool_parsers = None
# Filter supported_tools when tool_choice specifies a particular function
if supported_tools and isinstance(tool_choice, dict):
specified_func = tool_choice.get("function", {}).get("name")
if specified_func and specified_func in supported_tools:
supported_tools = [specified_func]
if supported_tools is not None:
_template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '')
_tool_parsers, _, _ = detect_tool_call_format(_template_str)
for a in generator:
answer = a['internal'][-1][1]
if supported_tools is not None:
tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else []
if len(tool_call) > 0:
for tc in tool_call:
tc["id"] = getToolCallId()
tc["index"] = len(tool_calls)
tc["id"] = get_tool_call_id()
if stream:
tc["index"] = len(tool_calls)
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
tool_calls.append(tc)
end_last_tool_call = len(answer)
if stream:
len_seen = len(seen_content)
new_content = answer[len_seen:]
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
chunk = chat_streaming_chunk(new_content)
seen_content = answer
yield chunk
# stop generation if tool_calls were generated previously
# Stop generation before streaming content if tool_calls were detected,
# so that raw tool markup is not sent as content deltas.
if len(tool_calls) > 0:
break
if stream:
# Strip reasoning/thinking blocks so only final content is streamed.
# Reasoning is emitted separately as reasoning_content deltas.
reasoning, content = extract_reasoning(answer)
if reasoning is not None:
new_reasoning = reasoning[len(seen_reasoning):]
new_content = content[len(seen_content):]
else:
new_reasoning = None
new_content = answer[len(seen_content):]
if (not new_content and not new_reasoning) or chr(0xfffd) in (new_content or '') + (new_reasoning or ''):
continue
chunk = chat_streaming_chunk(
content=new_content if new_content else None,
reasoning_content=new_reasoning if new_reasoning else None,
)
if include_usage:
chunk['usage'] = None
if reasoning is not None:
seen_reasoning = reasoning
seen_content = content
else:
seen_content = answer
yield chunk
token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if len(tool_calls) > 0:
stop_reason = "tool_calls"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
elif token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
stop_reason = "length"
else:
stop_reason = "stop"
if stream:
chunk = chat_streaming_chunk('', tool_calls)
chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls)
chunk[resp_list][0]['finish_reason'] = stop_reason
chunk['usage'] = {
usage = {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
yield chunk
if include_usage:
chunk['usage'] = None
yield chunk
# Separate usage-only chunk with choices: [] per OpenAI spec
yield {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
"system_fingerprint": None,
resp_list: [],
"usage": usage
}
else:
yield chunk
else:
reasoning, content = extract_reasoning(answer)
message = {
"role": "assistant",
"refusal": None,
"content": None if tool_calls else content,
**({"reasoning_content": reasoning} if reasoning else {}),
**({"tool_calls": tool_calls} if tool_calls else {}),
}
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
"system_fingerprint": None,
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"message": {"role": "assistant", "content": answer, **({"tool_calls": tool_calls} if tool_calls else {})},
"message": message,
"logprobs": None,
}],
"usage": {
"prompt_tokens": token_count,
@ -399,11 +603,19 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
"total_tokens": token_count + completion_token_count
}
}
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# resp[resp_list][0]["logprobs"] = None
if logprob_proc:
all_entries = []
for alt in logprob_proc.token_alternatives_history:
all_entries.extend(_dict_to_logprob_entries(alt))
formatted = format_chat_logprobs(all_entries)
if formatted:
resp[resp_list][0]["logprobs"] = formatted
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
raw = getattr(shared.model, 'last_completion_probabilities', None)
if raw:
formatted = format_chat_logprobs(raw)
if formatted:
resp[resp_list][0]["logprobs"] = formatted
yield resp
@ -411,7 +623,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
object_type = 'text_completion'
created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
cmpl_id = "cmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
prompt_str = 'context' if is_legacy else 'prompt'
@ -445,6 +657,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
generate_params['stop_event'] = stop_event
requested_model = generate_params.pop('model')
logprob_proc = generate_params.pop('logprob_proc', None)
if logprob_proc:
logprob_proc.token_alternatives_history.clear()
suffix = body['suffix'] if body['suffix'] else ''
echo = body['echo']
@ -456,6 +670,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
logger.info(f"Found {len(raw_images)} image(s) in request.")
generate_params['raw_images'] = raw_images
n_completions = body.get('n', 1) or 1
if not stream:
prompt_arg = body[prompt_str]
@ -469,6 +685,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
resp_list_data = []
total_completion_token_count = 0
total_prompt_token_count = 0
choice_index = 0
for idx, prompt in enumerate(prompt_arg, start=0):
if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
@ -483,37 +700,59 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
prompt = decode(prompt)[0]
prefix = prompt if echo else ''
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
for a in generator:
answer = a
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
respi = {
"index": idx,
"finish_reason": stop_reason,
"text": prefix + answer + suffix,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}
original_seed = generate_params.get('seed', -1)
for _n in range(n_completions):
# Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
if original_seed >= 0:
generate_params['seed'] = original_seed + _n
resp_list_data.extend([respi])
if logprob_proc:
logprob_proc.token_alternatives_history.clear()
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
for a in generator:
answer = a
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
if logprob_proc:
all_entries = []
for alt in logprob_proc.token_alternatives_history:
all_entries.extend(_dict_to_logprob_entries(alt))
completion_logprobs = format_completion_logprobs(all_entries)
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
raw = getattr(shared.model, 'last_completion_probabilities', None)
completion_logprobs = format_completion_logprobs(raw)
else:
completion_logprobs = None
respi = {
"index": choice_index,
"finish_reason": stop_reason,
"text": prefix + answer + suffix,
"logprobs": completion_logprobs,
}
resp_list_data.append(respi)
choice_index += 1
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
"system_fingerprint": None,
resp_list: resp_list_data,
"usage": {
"prompt_tokens": total_prompt_token_count,
@ -538,24 +777,41 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
prefix = prompt if echo else ''
token_count = len(encode(prompt)[0])
# Check if usage should be included in streaming chunks per OpenAI spec
stream_options = body.get('stream_options')
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
cmpl_logprobs_offset = [0] # mutable for closure access in streaming
def text_streaming_chunk(content):
# begin streaming
if logprob_proc:
chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives))
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0])
chunk_logprobs = format_completion_logprobs(entries) if entries else None
else:
chunk_logprobs = None
chunk = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
"system_fingerprint": None,
resp_list: [{
"index": 0,
"finish_reason": None,
"text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
"logprobs": chunk_logprobs,
}],
}
return chunk
yield text_streaming_chunk(prefix)
chunk = text_streaming_chunk(prefix)
if include_usage:
chunk['usage'] = None
yield chunk
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
@ -575,6 +831,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
seen_content = answer
chunk = text_streaming_chunk(new_content)
if include_usage:
chunk['usage'] = None
yield chunk
completion_token_count = len(encode(answer)[0])
@ -584,13 +842,27 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
chunk = text_streaming_chunk(suffix)
chunk[resp_list][0]["finish_reason"] = stop_reason
chunk["usage"] = {
usage = {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
yield chunk
if include_usage:
chunk['usage'] = None
yield chunk
# Separate usage-only chunk with choices: [] per OpenAI spec
yield {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
"system_fingerprint": None,
resp_list: [],
"usage": usage
}
else:
yield chunk
def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:

View file

@ -1,4 +1,4 @@
from modules import shared, ui
from modules import loaders, shared
from modules.logging_colors import logger
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
@ -20,10 +20,14 @@ def list_models():
def list_models_openai_format():
"""Returns model list in OpenAI API format"""
model_names = get_available_models()
if shared.model_name and shared.model_name != 'None':
data = [model_info_dict(shared.model_name)]
else:
data = []
return {
"object": "list",
"data": [model_info_dict(name) for name in model_names]
"data": data
}
@ -50,7 +54,7 @@ def _load_model(data):
# parameters exposed in the UI. Never allow security-sensitive
# flags like trust_remote_code or extra_flags to be set via the API.
blocked_keys = {'extra_flags'}
allowed_keys = set(ui.list_model_elements()) - blocked_keys
allowed_keys = set(loaders.list_model_elements()) - blocked_keys
if args:
for k in args:
if k in allowed_keys and hasattr(shared.args, k):

View file

@ -21,6 +21,7 @@ import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.errors import OpenAIError
from extensions.openai.utils import _start_cloudflared
from modules import shared
from modules.logging_colors import logger
@ -94,6 +95,20 @@ app.add_middleware(
)
@app.exception_handler(OpenAIError)
async def openai_error_handler(request: Request, exc: OpenAIError):
error_type = "server_error" if exc.code >= 500 else "invalid_request_error"
return JSONResponse(
status_code=exc.code,
content={"error": {
"message": exc.message,
"type": error_type,
"param": getattr(exc, 'param', None),
"code": None
}}
)
@app.middleware("http")
async def validate_host_header(request: Request, call_next):
# Be strict about only approving access to localhost by default
@ -119,6 +134,12 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
is_legacy = "/generate" in path
if request_data.stream:
if (request_data.n or 1) > 1:
return JSONResponse(
status_code=400,
content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
)
stop_event = threading.Event()
async def generator():
@ -130,6 +151,8 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
break
yield {"data": json.dumps(resp)}
yield {"data": "[DONE]"}
finally:
stop_event.set()
response.close()
@ -170,6 +193,8 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
break
yield {"data": json.dumps(resp)}
yield {"data": "[DONE]"}
finally:
stop_event.set()
response.close()
@ -433,10 +458,13 @@ def run_server():
# In the server configuration:
server_addrs = []
if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
server_addrs.append('[::]' if shared.args.listen else '[::1]')
if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1')
if shared.args.listen and shared.args.listen_host:
server_addrs.append(shared.args.listen_host)
else:
if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
server_addrs.append('[::]' if shared.args.listen else '[::1]')
if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1')
if not server_addrs:
raise Exception('you MUST enable IPv6 or IPv4 for the API to work')
@ -447,11 +475,11 @@ def run_server():
port,
shared.args.public_api_id,
max_attempts=3,
on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}\n')
on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}/v1\n')
)
else:
url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
urls = [f'{url_proto}{addr}:{port}' for addr in server_addrs]
urls = [f'{url_proto}{addr}:{port}/v1' for addr in server_addrs]
if len(urls) > 1:
logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
else:

View file

@ -99,6 +99,10 @@ class ToolCall(BaseModel):
function: FunctionCall
class StreamOptions(BaseModel):
include_usage: bool | None = False
class CompletionRequestParams(BaseModel):
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
@ -109,10 +113,11 @@ class CompletionRequestParams(BaseModel):
logit_bias: dict | None = None
logprobs: int | None = None
max_tokens: int | None = 512
n: int | None = Field(default=1, description="Unused parameter.")
n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.")
presence_penalty: float | None = shared.args.presence_penalty
stop: str | List[str] | None = None
stream: bool | None = False
stream_options: StreamOptions | None = None
suffix: str | None = None
temperature: float | None = shared.args.temperature
top_p: float | None = shared.args.top_p
@ -145,16 +150,27 @@ class ChatCompletionRequestParams(BaseModel):
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
logit_bias: dict | None = None
logprobs: bool | None = None
top_logprobs: int | None = None
max_tokens: int | None = None
max_completion_tokens: int | None = None
n: int | None = Field(default=1, description="Unused parameter.")
presence_penalty: float | None = shared.args.presence_penalty
stop: str | List[str] | None = None
stream: bool | None = False
stream_options: StreamOptions | None = None
temperature: float | None = shared.args.temperature
top_p: float | None = shared.args.top_p
user: str | None = Field(default=None, description="Unused parameter.")
@model_validator(mode='after')
def resolve_max_tokens(self):
if self.max_tokens is None and self.max_completion_tokens is not None:
self.max_tokens = self.max_completion_tokens
return self
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
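The `resolve_max_tokens` validator above lets `max_completion_tokens` act as a fallback for `max_tokens`. A minimal standard-library sketch of the same precedence rule follows; `TokenLimits` is a hypothetical stand-in used only for illustration, not the project's Pydantic model.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TokenLimits:
    # Hypothetical stand-in for the two request fields in the diff above.
    max_tokens: Optional[int] = None
    max_completion_tokens: Optional[int] = None

    def __post_init__(self):
        # Same rule as the validator: fall back only when max_tokens is unset.
        if self.max_tokens is None and self.max_completion_tokens is not None:
            self.max_tokens = self.max_completion_tokens


only_new_field = TokenLimits(max_completion_tokens=256)
both_fields = TokenLimits(max_tokens=64, max_completion_tokens=256)
neither_field = TokenLimits()
```

Under this rule an explicit `max_tokens` always wins, so clients that send both fields keep their existing behavior.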


@ -1,8 +1,5 @@
import base64
import json
import os
import random
import re
import time
import traceback
from typing import Callable, Optional
@ -55,473 +52,3 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
time.sleep(3)
raise Exception('Could not start cloudflared.')
def getToolCallId() -> str:
letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b = [random.choice(letter_bytes) for _ in range(8)]
return "call_" + "".join(b).lower()
def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
# check if property 'function' exists and is a dictionary, otherwise adapt dict
if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
candidate_dict = {"type": "function", "function": candidate_dict}
if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
candidate_dict['name'] = candidate_dict['function']
del candidate_dict['function']
candidate_dict = {"type": "function", "function": candidate_dict}
if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
# check if 'name' exists within 'function' and is part of known tools
if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value
# map property 'parameters' used by some older models to 'arguments'
if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
del candidate_dict["function"]["parameters"]
return candidate_dict
return None
def _extractBalancedJson(text: str, start: int) -> str | None:
"""Extract a balanced JSON object from text starting at the given position.
Walks through the string tracking brace depth and string boundaries
to correctly handle arbitrary nesting levels.
"""
if start >= len(text) or text[start] != '{':
return None
depth = 0
in_string = False
escape_next = False
for i in range(start, len(text)):
c = text[i]
if escape_next:
escape_next = False
continue
if c == '\\' and in_string:
escape_next = True
continue
if c == '"':
in_string = not in_string
continue
if in_string:
continue
if c == '{':
depth += 1
elif c == '}':
depth -= 1
if depth == 0:
return text[start:i + 1]
return None
def _parseChannelToolCalls(answer: str, tool_names: list[str]):
"""Parse channel-based tool calls used by GPT-OSS and similar models.
Format:
<|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
"""
matches = []
for m in re.finditer(
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches
def _parseBareNameToolCalls(answer: str, tool_names: list[str]):
"""Parse bare function-name style tool calls used by Mistral and similar models.
Format:
functionName{"arg": "value"}
Multiple calls are concatenated directly or separated by whitespace.
"""
matches = []
# Match tool name followed by opening brace, then extract balanced JSON
escaped_names = [re.escape(name) for name in tool_names]
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
for match in re.finditer(pattern, answer):
text = match.group(0)
name = None
for n in tool_names:
if text.startswith(n):
name = n
break
if not name:
continue
brace_start = match.end() - 1
json_str = _extractBalancedJson(answer, brace_start)
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches
def _parseXmlParamToolCalls(answer: str, tool_names: list[str]):
"""Parse XML-parameter style tool calls used by Qwen3.5 and similar models.
Format:
<tool_call>
<function=function_name>
<parameter=param_name>value</parameter>
</function>
</tool_call>
"""
matches = []
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
func_match = re.search(r'<function=([^>]+)>', tc_content)
if not func_match:
continue
func_name = func_match.group(1).strip()
if func_name not in tool_names:
continue
arguments = {}
for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
param_name = param_match.group(1).strip()
param_value = param_match.group(2).strip()
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass # keep as string
arguments[param_name] = param_value
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches
def _parseKimiToolCalls(answer: str, tool_names: list[str]):
"""Parse Kimi-K2-style tool calls using pipe-delimited tokens.
Format:
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
<|tool_calls_section_end|>
"""
matches = []
for m in re.finditer(
r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches
def _parseMiniMaxToolCalls(answer: str, tool_names: list[str]):
"""Parse MiniMax-style tool calls using invoke/parameter XML tags.
Format:
<minimax:tool_call>
<invoke name="function_name">
<parameter name="param_name">value</parameter>
</invoke>
</minimax:tool_call>
"""
matches = []
for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
# Split on <invoke> to handle multiple parallel calls in one block
for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
func_name = invoke_match.group(1).strip()
if func_name not in tool_names:
continue
invoke_body = invoke_match.group(2)
arguments = {}
for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
param_name = param_match.group(1).strip()
param_value = param_match.group(2).strip()
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass # keep as string
arguments[param_name] = param_value
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches
def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]):
"""Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
Format:
<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>func_name<｜tool▁sep｜>{"arg": "value"}<｜tool▁call▁end｜><｜tool▁calls▁end｜>
"""
matches = []
for m in re.finditer(
r'<｜tool▁call▁begin｜>\s*(\S+?)\s*<｜tool▁sep｜>\s*',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches
def _parseGlmToolCalls(answer: str, tool_names: list[str]):
"""Parse GLM-style tool calls using arg_key/arg_value XML pairs.
Format:
<tool_call>function_name
<arg_key>key1</arg_key>
<arg_value>value1</arg_value>
</tool_call>
"""
matches = []
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
# First non-tag text is the function name
name_match = re.match(r'([^<\s]+)', tc_content.strip())
if not name_match:
continue
func_name = name_match.group(1).strip()
if func_name not in tool_names:
continue
# Extract arg_key/arg_value pairs
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
if len(keys) != len(vals):
continue
arguments = {}
for k, v in zip(keys, vals):
try:
v = json.loads(v)
except (json.JSONDecodeError, ValueError):
pass # keep as string
arguments[k] = v
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches
def _parsePythonicToolCalls(answer: str, tool_names: list[str]):
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
Format:
[func_name(param1="value1", param2="value2"), func_name2(...)]
"""
matches = []
# Match a bracketed list of function calls
bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
if not bracket_match:
return matches
inner = bracket_match.group(1)
# Build pattern for known tool names
escaped_names = [re.escape(name) for name in tool_names]
name_pattern = '|'.join(escaped_names)
for call_match in re.finditer(
r'(' + name_pattern + r')\(([^)]*)\)',
inner
):
func_name = call_match.group(1)
params_str = call_match.group(2).strip()
arguments = {}
if params_str:
# Parse key="value" pairs, handling commas inside quoted values
for param_match in re.finditer(
r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
params_str
):
param_name = param_match.group(1)
param_value = param_match.group(2).strip()
# Strip surrounding quotes
if (param_value.startswith('"') and param_value.endswith('"')) or \
(param_value.startswith("'") and param_value.endswith("'")):
param_value = param_value[1:-1]
# Try to parse as JSON for numeric/bool/null values
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass
arguments[param_name] = param_value
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches
def parseToolCall(answer: str, tool_names: list[str]):
matches = []
# abort on very short answers to save computation cycles
if len(answer) < 10:
return matches
# Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
matches = _parseDeepSeekToolCalls(answer, tool_names)
if matches:
return matches
# Check for Kimi-K2-style tool calls (pipe-delimited tokens)
matches = _parseKimiToolCalls(answer, tool_names)
if matches:
return matches
# Check for channel-based tool calls (e.g. GPT-OSS format)
matches = _parseChannelToolCalls(answer, tool_names)
if matches:
return matches
# Check for MiniMax-style tool calls (invoke/parameter XML tags)
matches = _parseMiniMaxToolCalls(answer, tool_names)
if matches:
return matches
# Check for GLM-style tool calls (arg_key/arg_value XML pairs)
matches = _parseGlmToolCalls(answer, tool_names)
if matches:
return matches
# Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
matches = _parseXmlParamToolCalls(answer, tool_names)
if matches:
return matches
# Check for bare function-name style tool calls (e.g. Mistral format)
matches = _parseBareNameToolCalls(answer, tool_names)
if matches:
return matches
# Check for pythonic-style tool calls (e.g. Llama 4 format)
matches = _parsePythonicToolCalls(answer, tool_names)
if matches:
return matches
# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
for pattern in patterns:
for match in re.finditer(pattern, answer, re.DOTALL):
# print(match.group(2))
if match.group(2) is None:
continue
# remove backtick wraps if present
candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
candidate = re.sub(r"```$", "", candidate.strip())
# unwrap inner tags
candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
if not candidate.strip().startswith("["):
candidate = "[" + candidate + "]"
candidates = []
try:
# parse the candidate JSON into a dictionary
candidates = json.loads(candidate)
if not isinstance(candidates, list):
candidates = [candidates]
except json.JSONDecodeError:
# Ignore invalid JSON silently
continue
for candidate_dict in candidates:
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
if checked_candidate is not None:
matches.append(checked_candidate)
# last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
if len(matches) == 0:
try:
candidate = answer
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
if not candidate.strip().startswith("["):
candidate = "[" + candidate + "]"
# parse the candidate JSON into a dictionary
candidates = json.loads(candidate)
if not isinstance(candidates, list):
candidates = [candidates]
for candidate_dict in candidates:
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
if checked_candidate is not None:
matches.append(checked_candidate)
except json.JSONDecodeError:
# Ignore invalid JSON silently
pass
return matches
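Several of the parsers removed above (channel, bare-name, Kimi, DeepSeek) rely on the same idea: find where the arguments start, then walk forward counting braces while ignoring anything inside a JSON string. The sketch below restates that walk as a standalone function with a small usage example. The name `extract_balanced_json` and the sample text are illustrative assumptions, not the project's API.

```python
import json


def extract_balanced_json(text, start):
    """Return the balanced {...} substring beginning at text[start], or None."""
    if start >= len(text) or text[start] != '{':
        return None
    depth = 0
    in_string = False
    escape_next = False
    for i in range(start, len(text)):
        c = text[i]
        if escape_next:
            # The previous character was a backslash inside a string.
            escape_next = False
            continue
        if c == '\\' and in_string:
            escape_next = True
            continue
        if c == '"':
            in_string = not in_string
            continue
        if in_string:
            # Braces inside string literals do not affect nesting depth.
            continue
        if c == '{':
            depth += 1
        elif c == '}':
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None  # Unbalanced input, e.g. a truncated generation.


# Hypothetical model output in the bare-name style: name followed by JSON.
sample = 'get_weather{"city": "Paris", "opts": {"note": "a } inside a string"}} trailing text'
payload = extract_balanced_json(sample, sample.index('{'))
arguments = json.loads(payload)
```

A plain regex such as `\{.*?\}` would stop at the first closing brace, which is why a depth counter is used for nested arguments.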


@ -269,7 +269,49 @@ function removeLastClick() {
document.getElementById("Remove-last").click();
}
function autoScrollToBottom() {
if (!window.isScrolled) {
const chatParent = document.getElementById("chat")?.parentNode?.parentNode?.parentNode;
if (chatParent) {
const maxScroll = chatParent.scrollHeight - chatParent.clientHeight;
if (maxScroll > 0 && chatParent.scrollTop < maxScroll - 1) {
chatParent.scrollTop = maxScroll;
}
}
}
}
function updateInstructPadding() {
const chatElement = document.getElementById("chat");
if (chatElement && chatElement.getAttribute("data-mode") === "instruct") {
const messagesContainer = chatElement.querySelector(".messages");
const lastChild = messagesContainer?.lastElementChild;
const prevSibling = lastChild?.previousElementSibling;
if (lastChild && prevSibling && chatElement.offsetHeight > 0) {
let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight);
if (window.innerWidth <= 924) {
bufferHeight = Math.max(0, bufferHeight - 32);
}
messagesContainer.style.paddingBottom = `${bufferHeight}px`;
}
}
}
let pendingMorphdomData = null;
let morphdomRafId = null;
function handleMorphdomUpdate(data) {
pendingMorphdomData = data;
if (!morphdomRafId) {
morphdomRafId = requestAnimationFrame(() => {
morphdomRafId = null;
applyMorphdomUpdate(pendingMorphdomData);
pendingMorphdomData = null;
});
}
}
function applyMorphdomUpdate(data) {
// Determine target element and use it as query scope
var target_element, target_html;
if (data.last_message_only) {
@ -283,27 +325,21 @@ function handleMorphdomUpdate(data) {
const queryScope = target_element;
// Track open blocks
// Track open blocks and store their scroll positions
const openBlocks = new Set();
const scrollPositions = {};
queryScope.querySelectorAll(".thinking-block").forEach(block => {
const blockId = block.getAttribute("data-block-id");
// If block exists and is open, add to open set
if (blockId && block.hasAttribute("open")) {
openBlocks.add(blockId);
}
});
// Store scroll positions for any open blocks
const scrollPositions = {};
queryScope.querySelectorAll(".thinking-block[open]").forEach(block => {
const content = block.querySelector(".thinking-content");
const blockId = block.getAttribute("data-block-id");
if (content && blockId) {
const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
scrollPositions[blockId] = {
position: content.scrollTop,
isAtBottom: isAtBottom
};
const content = block.querySelector(".thinking-content");
if (content) {
const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
scrollPositions[blockId] = {
position: content.scrollTop,
isAtBottom: isAtBottom
};
}
}
});
@ -313,8 +349,8 @@ function handleMorphdomUpdate(data) {
{
onBeforeElUpdated: function(fromEl, toEl) {
// Preserve code highlighting
if (fromEl.tagName === "PRE" && fromEl.querySelector("code[data-highlighted]")) {
const fromCode = fromEl.querySelector("code");
if (fromEl.tagName === "PRE") {
const fromCode = fromEl.querySelector("code[data-highlighted]");
const toCode = toEl.querySelector("code");
if (fromCode && toCode && fromCode.textContent === toCode.textContent) {
@ -359,10 +395,23 @@ function handleMorphdomUpdate(data) {
}
);
// Syntax highlighting and LaTeX
if (window.doSyntaxHighlighting) {
window.doSyntaxHighlighting();
}
// Auto-scroll runs both before and after padding update.
// Before: so content growth isn't hidden by padding absorption.
// After: so padding-added space is also scrolled into view.
autoScrollToBottom();
updateInstructPadding();
autoScrollToBottom();
// Add toggle listeners for new blocks
queryScope.querySelectorAll(".thinking-block").forEach(block => {
if (!block._hasToggleListener) {
block.addEventListener("toggle", function(e) {
const wasScrolled = window.isScrolled;
if (this.open) {
const content = this.querySelector(".thinking-content");
if (content) {
@ -371,6 +420,12 @@ function handleMorphdomUpdate(data) {
}, 0);
}
}
autoScrollToBottom();
updateInstructPadding();
autoScrollToBottom();
// Restore scroll state so the browser's layout adjustment
// from the toggle doesn't disable auto-scroll
window.isScrolled = wasScrolled;
});
block._hasToggleListener = true;
}


@ -2,6 +2,12 @@
// Main
// ------------------------------------------------
// Sync highlight.js theme with the actual Gradio theme
var defined_hljs_css = document.body.classList.contains("dark") ? "file/css/highlightjs/github-dark.min.css" : "file/css/highlightjs/github.min.css";
if (document.getElementById("highlight-css").getAttribute("href") !== defined_hljs_css) {
document.getElementById("highlight-css").setAttribute("href", defined_hljs_css);
}
let main_parent = document.getElementById("chat-tab").parentNode;
let extensions = document.getElementById("extensions");
@ -145,10 +151,13 @@ targetElement.classList.add("pretty_scrollbar");
targetElement.classList.add("chat-parent");
window.isScrolled = false;
let scrollTimeout;
let lastScrollTop = 0;
let lastScrollHeight = 0;
let lastClientHeight = 0;
targetElement.addEventListener("scroll", function() {
let diff = targetElement.scrollHeight - targetElement.clientHeight;
let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff == 0;
let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0;
// Add scrolling class to disable hover effects
if (window.isScrolled || !isAtBottomNow) {
@ -157,9 +166,12 @@ targetElement.addEventListener("scroll", function() {
if(isAtBottomNow) {
window.isScrolled = false;
} else {
} else if (targetElement.scrollTop < lastScrollTop && targetElement.scrollHeight >= lastScrollHeight && targetElement.clientHeight <= lastClientHeight) {
window.isScrolled = true;
}
lastScrollTop = targetElement.scrollTop;
lastScrollHeight = targetElement.scrollHeight;
lastClientHeight = targetElement.clientHeight;
// Clear previous timeout and set new one
clearTimeout(scrollTimeout);
@ -170,61 +182,28 @@ targetElement.addEventListener("scroll", function() {
});
// Create a MutationObserver instance
const observer = new MutationObserver(function(mutations) {
// Check if this is just the scrolling class being toggled
const isScrollingClassOnly = mutations.every(mutation =>
mutation.type === "attributes" &&
mutation.attributeName === "class" &&
mutation.target === targetElement
);
const observer = new MutationObserver(function() {
if (targetElement.classList.contains("_generating")) {
typing.parentNode.classList.add("visible-dots");
document.getElementById("stop").style.display = "flex";
document.getElementById("Generate").style.display = "none";
// If the user is near the bottom, ensure auto-scroll is enabled
// for the new reply. This catches cases where isScrolled was
// incorrectly set to true by layout shifts during page load, etc.
const diff = targetElement.scrollHeight - targetElement.clientHeight;
if (Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0) {
window.isScrolled = false;
}
} else {
typing.parentNode.classList.remove("visible-dots");
document.getElementById("stop").style.display = "none";
document.getElementById("Generate").style.display = "flex";
}
doSyntaxHighlighting();
if (!window.isScrolled && !isScrollingClassOnly) {
const maxScroll = targetElement.scrollHeight - targetElement.clientHeight;
if (maxScroll > 0 && targetElement.scrollTop < maxScroll - 1) {
targetElement.scrollTop = maxScroll;
}
}
const chatElement = document.getElementById("chat");
if (chatElement && chatElement.getAttribute("data-mode") === "instruct") {
const messagesContainer = chatElement.querySelector(".messages");
const lastChild = messagesContainer?.lastElementChild;
const prevSibling = lastChild?.previousElementSibling;
if (lastChild && prevSibling) {
// Add padding to the messages container to create room for the last message.
// The purpose of this is to avoid constant scrolling during streaming in
// instruct mode.
let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight);
// Subtract header height when screen width is <= 924px
if (window.innerWidth <= 924) {
bufferHeight = Math.max(0, bufferHeight - 32);
}
messagesContainer.style.paddingBottom = `${bufferHeight}px`;
}
}
});
// Configure the observer to watch for changes in the subtree and attributes
// Only watch for attribute changes on targetElement (e.g. _generating class)
const config = {
childList: true,
subtree: true,
characterData: true,
attributeOldValue: true,
characterDataOldValue: true
attributes: true
};
// Start observing the target element
@ -243,64 +222,76 @@ function isElementVisibleOnScreen(element) {
);
}
function doSyntaxHighlighting() {
window.doSyntaxHighlighting = function() {
const messageBodies = document.getElementById("chat").querySelectorAll(".message-body");
if (messageBodies.length > 0) {
observer.disconnect();
let hasSeenVisible = false;
try {
let hasSeenVisible = false;
// Go from last message to first
for (let i = messageBodies.length - 1; i >= 0; i--) {
const messageBody = messageBodies[i];
// Go from last message to first
for (let i = messageBodies.length - 1; i >= 0; i--) {
const messageBody = messageBodies[i];
if (isElementVisibleOnScreen(messageBody)) {
hasSeenVisible = true;
if (isElementVisibleOnScreen(messageBody)) {
hasSeenVisible = true;
// Handle both code and math in a single pass through each message
const codeBlocks = messageBody.querySelectorAll("pre code:not([data-highlighted])");
codeBlocks.forEach((codeBlock) => {
hljs.highlightElement(codeBlock);
codeBlock.setAttribute("data-highlighted", "true");
codeBlock.classList.add("pretty_scrollbar");
});
// Handle both code and math in a single pass through each message
const codeBlocks = messageBody.querySelectorAll("pre code:not([data-highlighted])");
codeBlocks.forEach((codeBlock) => {
hljs.highlightElement(codeBlock);
codeBlock.setAttribute("data-highlighted", "true");
codeBlock.classList.add("pretty_scrollbar");
});
// Only render math in visible elements
const mathContainers = messageBody.querySelectorAll("p, span, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, figcaption, caption, dd, dt");
mathContainers.forEach(container => {
if (isElementVisibleOnScreen(container)) {
renderMathInElement(container, {
delimiters: [
{ left: "$$", right: "$$", display: true },
{ left: "$", right: "$", display: false },
{ left: "\\(", right: "\\)", display: false },
{ left: "\\[", right: "\\]", display: true },
],
});
}
});
} else if (hasSeenVisible) {
// We've seen visible messages but this one is not visible
// Since we're going from last to first, we can break
break;
}
// Only render math in visible elements
const mathContainers = messageBody.querySelectorAll("p, span, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, figcaption, caption, dd, dt");
mathContainers.forEach(container => {
if (isElementVisibleOnScreen(container)) {
renderMathInElement(container, {
delimiters: [
{ left: "$$", right: "$$", display: true },
{ left: "$", right: "$", display: false },
{ left: "\\(", right: "\\)", display: false },
{ left: "\\[", right: "\\]", display: true },
],
});
}
});
} else if (hasSeenVisible) {
// We've seen visible messages but this one is not visible
// Since we're going from last to first, we can break
break;
}
} finally {
observer.observe(targetElement, config);
}
}
}
const doSyntaxHighlighting = window.doSyntaxHighlighting;
//------------------------------------------------
// Add some scrollbars
//------------------------------------------------
const textareaElements = document.querySelectorAll(".add_scrollbar textarea");
for(i = 0; i < textareaElements.length; i++) {
textareaElements[i].classList.remove("scroll-hide");
textareaElements[i].classList.add("pretty_scrollbar");
textareaElements[i].style.resize = "none";
const scrollbarElements = document.querySelectorAll(".add_scrollbar textarea, .add_scrollbar .drag-drop-list");
for(i = 0; i < scrollbarElements.length; i++) {
scrollbarElements[i].classList.remove("scroll-hide");
scrollbarElements[i].classList.add("pretty_scrollbar");
scrollbarElements[i].style.resize = "none";
}
//------------------------------------------------
// Tools: inject "Refresh list" link into the label
//------------------------------------------------
const toolsTitle = document.querySelector("#tools-group > [data-testid='block-info']");
const toolsInfo = toolsTitle ? toolsTitle.nextElementSibling : null;
if (toolsInfo) {
const refreshLink = document.createElement("span");
refreshLink.textContent = " [Refresh list]";
refreshLink.className = "tools-refresh-link";
refreshLink.addEventListener("click", function(e) {
e.preventDefault();
document.querySelector("#tools-refresh-btn").click();
});
toolsInfo.appendChild(refreshLink);
}
//------------------------------------------------
@ -561,6 +552,38 @@ document.querySelectorAll(".focus-on-chat-input").forEach(element => {
});
});
//------------------------------------------------
// "New chat" hover menu with incognito option
//------------------------------------------------
(function() {
const newChatBtn = document.getElementById("new-chat-btn");
const wrapper = document.createElement("div");
wrapper.id = "new-chat-wrapper";
newChatBtn.replaceWith(wrapper);
wrapper.appendChild(newChatBtn);
const arrow = document.createElement("span");
arrow.className = "new-chat-arrow";
arrow.textContent = "\u25BE";
const menu = document.createElement("div");
menu.className = "new-chat-menu";
const option = document.createElement("div");
option.className = "new-chat-menu-item";
option.textContent = "Incognito chat";
menu.appendChild(option);
arrow.appendChild(menu);
wrapper.appendChild(arrow);
option.addEventListener("click", function(e) {
e.stopPropagation();
document.querySelector("#incognito-chat-btn").click();
});
})();
//------------------------------------------------
// Fix a border around the "past chats" menu
//------------------------------------------------
@ -1090,15 +1113,13 @@ document.fonts.addEventListener("loadingdone", (event) => {
const currentHeight = chatInputRow.offsetHeight;
const heightDifference = currentHeight - originalHeight;
chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`;
if (!window.isScrolled) {
chatParent.scrollTop = chatParent.scrollHeight - chatParent.clientHeight;
}
}
// Watch for changes that might affect height
const observer = new MutationObserver(updateMargin);
observer.observe(chatInputRow, {
childList: true,
subtree: true,
attributes: true
});
// Watch for size changes that affect height
new ResizeObserver(updateMargin).observe(chatInputRow);
// Also listen for window resize
window.addEventListener("resize", updateMargin);


@ -6,12 +6,13 @@ import json
import pprint
import re
import shutil
import threading
import time
from datetime import datetime
from functools import partial
from pathlib import Path
import gradio as gr
import markupsafe
import yaml
from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment
@ -23,10 +24,12 @@ from modules.extensions import apply_extensions
from modules.html_generator import (
chat_html_wrapper,
convert_to_markdown,
extract_thinking_block,
make_thumbnail
)
from modules.image_utils import open_image_safely
from modules.logging_colors import logger
from modules.reasoning import THINKING_FORMATS
from modules.text_generation import (
generate_reply,
get_encoded_length,
@ -41,6 +44,8 @@ from modules.utils import (
)
from modules.web_search import add_web_search_attachments
_history_file_lock = threading.Lock()
def strftime_now(format):
return datetime.now().strftime(format)
@ -75,8 +80,22 @@ jinja_env = ImmutableSandboxedEnvironment(
lstrip_blocks=True,
extensions=[loopcontrols]
)
def custom_tojson(value, indent=None, ensure_ascii=True):
return markupsafe.Markup(json.dumps(value, indent=indent, ensure_ascii=ensure_ascii))
jinja_env.filters["tojson"] = custom_tojson
jinja_env.globals["strftime_now"] = strftime_now
def _raise_exception(message):
raise ValueError(message)
jinja_env.globals["raise_exception"] = _raise_exception
_template_cache = {}
@ -150,6 +169,49 @@ def _deserialize_tool_call_arguments(tool_calls):
return result
def _expand_tool_sequence(tool_seq):
"""Expand a tool_sequence list into API messages.
Returns a list of dicts (role: assistant with tool_calls, or role: tool).
If any tool_call IDs are missing a matching tool result, a synthetic
empty result is inserted so the prompt is never malformed.
"""
messages = []
expected_ids = []
seen_ids = set()
for item in tool_seq:
if 'tool_calls' in item:
deserialized = _deserialize_tool_call_arguments(item['tool_calls'])
messages.append({
"role": "assistant",
"content": item.get('content', ''),
"tool_calls": deserialized
})
for tc in item['tool_calls']:
tc_id = tc.get('id', '')
if tc_id:
expected_ids.append(tc_id)
elif item.get('role') == 'tool':
messages.append({
"role": "tool",
"content": item['content'],
"tool_call_id": item.get('tool_call_id', '')
})
seen_ids.add(item.get('tool_call_id', ''))
# Fill in synthetic results for any orphaned tool call IDs
for tc_id in expected_ids:
if tc_id not in seen_ids:
messages.append({
"role": "tool",
"content": "",
"tool_call_id": tc_id
})
return messages
def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs.get('impersonate', False)
_continue = kwargs.get('_continue', False)
@ -185,6 +247,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
name1=state['name1'],
name2=state['name2'],
user_bio=replace_character_names(state['user_bio'], state['name1'], state['name2']),
tools=state['tools'] if 'tools' in state else None,
)
messages = []
@ -296,7 +359,18 @@ def generate_chat_prompt(user_input, state, **kwargs):
if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant':
messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls'])
if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
# Expand tool_sequence from metadata (inserted AFTER assistant so that
# the final order is: user → tool_calls → tool_results → final_answer)
meta_key = f"assistant_{row_idx}"
tool_seq = metadata.get(meta_key, {}).get('tool_sequence', [])
if tool_seq:
for msg in reversed(_expand_tool_sequence(tool_seq)):
messages.insert(insert_pos, msg)
if entry_meta.get('role') == 'system':
if user_msg:
messages.insert(insert_pos, {"role": "system", "content": user_msg})
elif user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
# Check for user message attachments in metadata
user_key = f"user_{row_idx}"
enhanced_user_msg = user_msg
@ -365,6 +439,12 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.append({"role": "user", "content": user_input})
# Expand tool_sequence for the current entry (excluded from the
# history loop during regenerate — needed so the model sees prior
# tool calls and results when re-generating the final answer).
current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', [])
messages.extend(_expand_tool_sequence(current_tool_seq))
if impersonate and state['mode'] != 'chat-instruct':
messages.append({"role": "user", "content": "fake user message replace me"})
@ -810,6 +890,8 @@ def generate_search_query(user_message, state):
query = query.rsplit("</think>", 1)[1]
elif "<|start|>assistant<|channel|>final<|message|>" in query:
query = query.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
elif "<|channel|>final<|message|>" in query:
query = query.rsplit("<|channel|>final<|message|>", 1)[1]
elif "</seed:think>" in query:
query = query.rsplit("</seed:think>", 1)[1]
@ -884,7 +966,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
}
else:
text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
if regenerate:
if regenerate and not state.get('_tool_turn'):
row_idx = len(output['internal']) - 1
# Store the old response as a version before regenerating
@ -947,10 +1029,33 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
# Add timestamp for assistant's response at the start of generation
update_message_metadata(output['metadata'], "assistant", row_idx, timestamp=get_current_timestamp(), model_name=shared.model_name)
# Detect if the template appended a thinking start tag to the prompt
thinking_prefix = None
if not _continue:
stripped_prompt = prompt.rstrip('\n')
for start_tag, end_tag, content_tag in THINKING_FORMATS:
if start_tag is not None and stripped_prompt.endswith(start_tag):
thinking_prefix = start_tag
break
# When tools are active, buffer streaming output during potential tool
# call generation to prevent raw markup from leaking into the display.
_check_tool_markers = bool(state.get('tools'))
_last_visible_before_tool_buffer = None
if _check_tool_markers:
from modules.tool_parsing import streaming_tool_buffer_check, detect_tool_call_format
_tool_names = [t['function']['name'] for t in state['tools'] if 'function' in t and 'name' in t['function']]
_template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
_, _streaming_markers, _check_bare_names = detect_tool_call_format(_template_str)
# Generate
reply = None
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True, for_ui=for_ui)):
# Prepend thinking tag if the template appended it to the prompt
if thinking_prefix:
reply = thinking_prefix + reply
# Extract the reply
if state['mode'] in ['chat', 'chat-instruct']:
if not _continue:
@ -982,7 +1087,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
# Keep version metadata in sync during streaming (for regeneration)
if regenerate:
if regenerate and not state.get('_tool_turn'):
row_idx = len(output['internal']) - 1
key = f"assistant_{row_idx}"
current_idx = output['metadata'][key]['current_version_index']
@ -992,25 +1097,35 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
})
if is_stream:
if _check_tool_markers:
if streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
continue
_last_visible_before_tool_buffer = output['visible'][-1][1]
yield output
if _continue:
# Reprocess the entire internal text for extensions (like translation)
full_internal = output['internal'][-1][1]
if state['mode'] in ['chat', 'chat-instruct']:
full_visible = re.sub("(<USER>|<user>|{{user}})", state['name1'], full_internal)
else:
full_visible = full_internal
# Reprocess the entire internal text for extensions (like translation).
# Skip entirely when the visible text contains <tool_call> markers,
# since those only exist in visible (internal is cleared after each tool
# execution) and rebuilding from internal would destroy them. Output
# extensions also can't handle the raw <tool_call> markup safely.
if '<tool_call>' not in output['visible'][-1][1]:
full_internal = output['internal'][-1][1]
if state['mode'] in ['chat', 'chat-instruct']:
full_visible = re.sub("(<USER>|<user>|{{user}})", state['name1'], full_internal)
else:
full_visible = full_internal
full_visible = html.escape(full_visible)
if not state.get('_skip_output_extensions'):
output['visible'][-1][1] = apply_extensions('output', full_visible, state, is_chat=True)
full_visible = html.escape(full_visible)
if not state.get('_skip_output_extensions'):
output['visible'][-1][1] = apply_extensions('output', full_visible, state, is_chat=True)
else:
if not state.get('_skip_output_extensions'):
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
# Final sync for version metadata (in case streaming was disabled)
if regenerate:
if regenerate and not state.get('_tool_turn'):
row_idx = len(output['internal']) - 1
key = f"assistant_{row_idx}"
current_idx = output['metadata'][key]['current_version_index']
@ -1019,6 +1134,13 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
'visible_content': output['visible'][row_idx][1]
})
# When tool markers were detected during streaming, restore the last
# visible text from before buffering started so raw markup doesn't flash
# in the UI. The internal text is left intact so the caller can still
# parse tool calls from it.
if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
output['visible'][-1][1] = _last_visible_before_tool_buffer or ''
yield output
@ -1064,7 +1186,11 @@ def character_is_loaded(state, raise_exception=False):
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
'''
Same as above but returns HTML for the UI
Same as above but returns HTML for the UI.
When tools are selected, wraps generation in a loop that detects
tool calls, executes them, and re-generates until the model stops.
All tool output is consolidated into a single visible chat bubble
using metadata['assistant_N']['tool_sequence'].
'''
if not character_is_loaded(state):
@ -1079,19 +1205,257 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
send_dummy_message(text, state)
send_dummy_reply(state['start_with'], state)
history = state['history']
# On regenerate, clear old tool_sequence metadata so it gets rebuilt.
# Save it first so it can be stored per-version below.
# This must happen after the start_with logic above, which may remove
# and re-add messages, changing which row we operate on.
_old_tool_sequence = None
if regenerate:
history = state['history']
meta = history.get('metadata', {})
row_idx = len(history['internal']) - 1
if row_idx >= 0:
_old_tool_sequence = meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None)
# Load tools if any are selected
selected = state.get('selected_tools', [])
parse_tool_call = None
_tool_parsers = None
if selected:
from modules.tool_use import load_tools, execute_tool
from modules.tool_parsing import parse_tool_call, get_tool_call_id, detect_tool_call_format
if selected:
tool_defs, tool_executors = load_tools(selected)
state['tools'] = tool_defs
tool_func_names = [t['function']['name'] for t in tool_defs]
_template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
_tool_parsers, _, _ = detect_tool_call_format(_template_str)
else:
tool_func_names = None
visible_prefix = [] # Accumulated tool call summaries + results
last_save_time = time.monotonic()
save_interval = 8
for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
if i == 0:
time.sleep(0.125) # We need this to make sure the first update goes through
_tool_turn = 0
while True:
history = state['history']
current_time = time.monotonic()
# Save on first iteration or if save_interval seconds have passed
if i == 0 or (current_time - last_save_time) >= save_interval:
# Turn 0: use original flags; turns 1+: regenerate into the same entry.
# _tool_turn tells chatbot_wrapper to skip version creation/sync so
# that intermediate tool-loop regenerations don't pollute swipe history.
if _tool_turn > 0:
state['_tool_turn'] = True
state['_skip_output_extensions'] = True
regen = regenerate if _tool_turn == 0 else True
cont = _continue if _tool_turn == 0 else False
cur_text = text if _tool_turn == 0 else ''
for i, history in enumerate(generate_chat_reply(cur_text, state, regen, cont, loading_message=True, for_ui=True)):
# Prepend accumulated tool output to visible reply for display.
# Save and restore the original to prevent the markers from leaking
# back into chatbot_wrapper's shared output object, which would cause
# duplication on the next yield.
_original_visible = history['visible'][-1][1] if visible_prefix else None
if visible_prefix:
history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_original_visible])
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
if visible_prefix:
history['visible'][-1][1] = _original_visible
if i == 0:
# Save old tool_sequence into version 0 (created by chatbot_wrapper
# on the first yield). Only needed on the first regeneration when
# versions didn't previously exist.
if _old_tool_sequence is not None and _tool_turn == 0:
_ri = len(history['internal']) - 1
_versions = history.get('metadata', {}).get(f'assistant_{_ri}', {}).get('versions', [])
if _versions and 'tool_sequence' not in _versions[0]:
_versions[0]['tool_sequence'] = _old_tool_sequence
_old_tool_sequence = None
time.sleep(0.125)
current_time = time.monotonic()
if i == 0 or (current_time - last_save_time) >= save_interval:
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
last_save_time = current_time
# Early stop on tool call detection
if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names, parsers=_tool_parsers):
break
# Save the model's visible output before re-applying visible_prefix,
# so we can extract thinking content from just this turn's output.
_model_visible = history['visible'][-1][1]
# Recover visible_prefix from existing visible text (e.g. on Continue
# after a previous session had tool calls). Extract all <tool_call>
# blocks and any text between them (thinking blocks, intermediate text).
if tool_func_names and not visible_prefix and _model_visible:
tc_matches = list(re.finditer(r'<tool_call>.*?</tool_call>', _model_visible, re.DOTALL))
if tc_matches:
prefix_end = tc_matches[-1].end()
prefix = _model_visible[:prefix_end].strip()
if prefix:
visible_prefix = [prefix]
_model_visible = _model_visible[prefix_end:].strip()
# Re-apply visible prefix to the final state after streaming completes.
# This is safe because we're no longer sharing the object with chatbot_wrapper.
if visible_prefix:
history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_model_visible])
if tool_func_names:
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
last_save_time = current_time
# Check for tool calls
if not tool_func_names or shared.stop_everything:
break
answer = history['internal'][-1][1]
parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True, parsers=_tool_parsers) if answer else (None, '')
if not parsed_calls:
break # No tool calls — done
# --- Process tool calls ---
row_idx = len(history['internal']) - 1
meta = history.get('metadata', {})
seq = meta.setdefault(f'assistant_{row_idx}', {}).setdefault('tool_sequence', [])
def _render():
return chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
# Serialize tool calls and build display headers in one pass
serialized = []
tc_headers = []
for tc in parsed_calls:
tc['id'] = get_tool_call_id()
fn_name = tc['function']['name']
fn_args = tc['function'].get('arguments', {})
serialized.append({
'id': tc['id'],
'type': 'function',
'function': {
'name': fn_name,
'arguments': json.dumps(fn_args) if isinstance(fn_args, dict) else fn_args
}
})
if isinstance(fn_args, dict) and fn_args:
args_summary = ', '.join(f'{k}={json.dumps(v, ensure_ascii=False)}' for k, v in fn_args.items())
elif isinstance(fn_args, dict):
args_summary = ''
else:
args_summary = str(fn_args)
tc_headers.append(f'{fn_name}({args_summary})')
seq_entry = {'tool_calls': serialized}
if content_prefix.strip():
# Strip GPT-OSS channel tokens so they don't get double-wrapped
# by the template (which adds its own channel markup).
clean = content_prefix.strip()
if '<|channel|>' in clean and '<|message|>' in clean:
inner = clean.split('<|message|>', 1)[1]
if '<|end|>' in inner:
inner = inner.split('<|end|>', 1)[0]
clean = inner.strip()
if clean:
seq_entry['content'] = clean
seq.append(seq_entry)
# Clear internal (raw tool markup)
history['internal'][-1][1] = ''
# Preserve thinking block and intermediate text from this turn.
# content_prefix is the raw text before tool call syntax (returned
# by parse_tool_call); HTML-escape it and extract thinking to get
# the content the user should see.
content_text = html.escape(content_prefix)
thinking_content, intermediate = extract_thinking_block(content_text)
if thinking_content:
visible_prefix.append(f'&lt;think&gt;\n{thinking_content}\n&lt;/think&gt;')
if intermediate and intermediate.strip():
visible_prefix.append(intermediate.strip())
# Show placeholder accordions with "..." before execution starts
# (tool calls may be slow, e.g. web search).
pending_placeholders = [f'<tool_call>{h}\n...\n</tool_call>' for h in tc_headers]
history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
yield _render(), history
# Execute tools, store results, and replace placeholders with real results
for i, tc in enumerate(parsed_calls):
# Check for stop request before each tool execution
if shared.stop_everything:
for j in range(i, len(parsed_calls)):
seq.append({'role': 'tool', 'content': 'Tool execution was cancelled by the user.', 'tool_call_id': parsed_calls[j]['id']})
pending_placeholders[j] = f'<tool_call>{tc_headers[j]}\nCancelled\n</tool_call>'
history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
yield _render(), history
break
fn_name = tc['function']['name']
fn_args = tc['function'].get('arguments', {})
result = execute_tool(fn_name, fn_args, tool_executors)
seq.append({'role': 'tool', 'content': result, 'tool_call_id': tc['id']})
try:
pretty_result = json.dumps(json.loads(result), indent=2, ensure_ascii=False)
except (json.JSONDecodeError, TypeError):
pretty_result = result
# Replace the placeholder with the real result
pending_placeholders[i] = f'<tool_call>{tc_headers[i]}\n{pretty_result}\n</tool_call>'
history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
yield _render(), history
# Move completed tool calls into visible_prefix for next turns
visible_prefix.extend(pending_placeholders)
history['visible'][-1][1] = '\n\n'.join(visible_prefix)
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
state['history'] = history
_tool_turn += 1
state.pop('_tool_turn', None)
# If output extensions were deferred during tool turns, apply them now
# to the final model response only (not to tool call markers).
if state.pop('_skip_output_extensions', None):
_model_visible = apply_extensions('output', _model_visible, state, is_chat=True)
if visible_prefix:
history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_model_visible])
else:
history['visible'][-1][1] = _model_visible
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
state['history'] = history
# Sync version metadata so swipes show the full visible (with tool prefix)
if visible_prefix and history.get('metadata'):
row_idx = len(history['internal']) - 1
key = f"assistant_{row_idx}"
meta_entry = history['metadata'].get(key, {})
if 'versions' in meta_entry and 'current_version_index' in meta_entry:
current_idx = meta_entry['current_version_index']
if current_idx < len(meta_entry['versions']):
version_update = {
'content': history['internal'][row_idx][1],
'visible_content': history['visible'][row_idx][1]
}
ts = meta_entry.get('tool_sequence')
if ts is not None:
version_update['tool_sequence'] = ts
meta_entry['versions'][current_idx].update(version_update)
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
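The channel-token cleanup applied to `content_prefix` in the tool loop above can be isolated into a small helper for illustration. This is a sketch that mirrors the inline logic; the helper name `strip_channel_markup` and the `commentary` channel name in the sample string are assumptions, not names taken from the commit.

```python
def strip_channel_markup(content_prefix):
    """Return the text between <|message|> and <|end|> when channel tokens are present."""
    clean = content_prefix.strip()
    if '<|channel|>' in clean and '<|message|>' in clean:
        # Keep everything after the first <|message|> token.
        inner = clean.split('<|message|>', 1)[1]
        # Drop the trailing <|end|> token and anything after it.
        if '<|end|>' in inner:
            inner = inner.split('<|end|>', 1)[0]
        clean = inner.strip()
    return clean


# Hypothetical raw prefix emitted before a tool call.
raw_prefix = "<|channel|>commentary<|message|>Let me check the weather.<|end|>"
cleaned = strip_channel_markup(raw_prefix)

# Text without channel tokens passes through unchanged (apart from stripping).
plain = strip_channel_markup("  I will call a tool now.  ")
```

Only the first `<|message|>` segment is kept, so a prefix containing several channel blocks would lose everything after the first `<|end|>`; the inline code in the diff behaves the same way.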
@ -1164,7 +1528,7 @@ def redraw_html(history, name1, name2, mode, style, character, reset_cache=False
return chat_html_wrapper(history, name1, name2, mode, style, character, reset_cache=reset_cache)
def start_new_chat(state):
def start_new_chat(state, unique_id=None):
mode = state['mode']
# Initialize with empty metadata dictionary
history = {'internal': [], 'visible': [], 'metadata': {}}
@ -1178,7 +1542,9 @@ def start_new_chat(state):
# Add timestamp for assistant's greeting
update_message_metadata(history['metadata'], "assistant", 0, timestamp=get_current_timestamp())
unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
if unique_id is None:
unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
save_history(history, unique_id, state['character_menu'], state['mode'])
return history
@ -1197,12 +1563,16 @@ def save_history(history, unique_id, character, mode):
if shared.args.multi_user:
return
if unique_id and unique_id.startswith('incognito-'):
return
p = get_history_file_path(unique_id, character, mode)
if not p.parent.is_dir():
p.parent.mkdir(parents=True)
with open(p, 'w', encoding='utf-8') as f:
f.write(json.dumps(history, indent=4, ensure_ascii=False))
with _history_file_lock:
with open(p, 'w', encoding='utf-8') as f:
f.write(json.dumps(history, indent=4, ensure_ascii=False))
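The two behaviors added to `save_history` here (skipping `incognito-` chats and serializing writes behind a module-level lock) can be sketched in isolation. This is a minimal sketch under stated assumptions: `save_history_sketch`, its `base_dir` parameter, the `<unique_id>.json` file layout, and the return value are hypothetical; the real function resolves the path with `get_history_file_path` and also returns early in multi-user mode.

```python
import json
import tempfile
import threading
from pathlib import Path

_history_file_lock = threading.Lock()


def save_history_sketch(history, unique_id, base_dir):
    """Write history as JSON unless the chat is incognito; return the path or None."""
    # Incognito chats are never persisted to disk.
    if unique_id and unique_id.startswith('incognito-'):
        return None

    p = Path(base_dir) / f'{unique_id}.json'  # assumed layout, for illustration only
    if not p.parent.is_dir():
        p.parent.mkdir(parents=True)

    # The lock keeps concurrent saves from interleaving their writes.
    with _history_file_lock:
        with open(p, 'w', encoding='utf-8') as f:
            f.write(json.dumps(history, indent=4, ensure_ascii=False))

    return p


with tempfile.TemporaryDirectory() as tmp:
    history = {'internal': [], 'visible': [], 'metadata': {}}
    saved = save_history_sketch(history, '20260316-12-00-00', tmp)
    skipped = save_history_sketch(history, 'incognito-20260316-12-00-00', tmp)
    saved_exists = saved.exists()
    roundtrip = json.loads(saved.read_text(encoding='utf-8'))
```

A single process-wide lock is the simplest choice; it serializes all history writes, including writes to different files, which is likely acceptable for small JSON files.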
def rename_history(old_id, new_id, character, mode):
@ -1333,6 +1703,7 @@ def load_history_after_deletion(state, idx):
Loads the latest history for the given character in chat or chat-instruct
mode, or the latest instruct history for instruct mode.
'''
import gradio as gr
if shared.args.multi_user:
return start_new_chat(state)
@ -1351,6 +1722,7 @@ def load_history_after_deletion(state, idx):
def update_character_menu_after_deletion(idx):
import gradio as gr
characters = utils.get_available_characters()
idx = min(int(idx), len(characters) - 1)
idx = max(0, idx)
@ -1383,6 +1755,9 @@ def save_last_chat_state(character, mode, unique_id):
if shared.args.multi_user:
return
if unique_id and unique_id.startswith('incognito-'):
return
state = load_last_chat_state()
key = get_chat_state_key(character, mode)
state["last_chats"][key] = unique_id
@ -1565,24 +1940,6 @@ def clear_character_for_ui(state):
return state, state['name2'], state['context'], state['greeting'], None
def load_instruction_template(template):
if template == 'None':
return ''
for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']:
if filepath.exists():
break
else:
return ''
file_contents = open(filepath, 'r', encoding='utf-8').read()
data = yaml.safe_load(file_contents)
if 'instruction_template' in data:
return data['instruction_template']
else:
return jinja_template_from_old_format(data)
@functools.cache
def load_character_memoized(character, name1, name2):
return load_character(character, name1, name2)
@ -1590,10 +1947,12 @@ def load_character_memoized(character, name1, name2):
@functools.cache
def load_instruction_template_memoized(template):
from modules.models_settings import load_instruction_template
return load_instruction_template(template)
def upload_character(file, img_path, tavern=False):
import gradio as gr
img = open_image_safely(img_path)
decoded_file = file if isinstance(file, str) else file.decode('utf-8')
try:
@ -1647,6 +2006,7 @@ def upload_tavern_character(img_path, _json):
def check_tavern_character(img_path):
import gradio as gr
img = open_image_safely(img_path)
if img is None:
@ -1832,6 +2192,7 @@ def delete_user(name):
def update_user_menu_after_deletion(idx):
"""Update user menu after a user is deleted"""
import gradio as gr
users = get_available_users()
if len(users) == 0:
# Create a default user if none exist
@ -1864,93 +2225,13 @@ def handle_user_menu_change(state):
def handle_save_user_click(name1):
"""Handle save user button click"""
import gradio as gr
return [
name1,
gr.update(visible=True)
]
def jinja_template_from_old_format(params, verbose=False):
MASTER_TEMPLATE = """
{%- set ns = namespace(found=false) -%}
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
{%- set ns.found = true -%}
{%- endif -%}
{%- endfor -%}
{%- if not ns.found -%}
{{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
{%- endif %}
{%- for message in messages %}
{%- if message['role'] == 'system' -%}
{{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
{%- else -%}
{%- if message['role'] == 'user' -%}
{{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
{%- else -%}
{{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{-'<|PRE-ASSISTANT-GENERATE|>'-}}
{%- endif -%}
"""
if 'context' in params and '<|system-message|>' in params['context']:
pre_system = params['context'].split('<|system-message|>')[0]
post_system = params['context'].split('<|system-message|>')[1]
else:
pre_system = ''
post_system = ''
pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]
pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
post_assistant = params['turn_template'].split('<|bot-message|>')[1]
def preprocess(string):
return string.replace('\n', '\\n').replace('\'', '\\\'')
pre_system = preprocess(pre_system)
post_system = preprocess(post_system)
pre_user = preprocess(pre_user)
post_user = preprocess(post_user)
pre_assistant = preprocess(pre_assistant)
post_assistant = preprocess(post_assistant)
if verbose:
print(
'\n',
repr(pre_system) + '\n',
repr(post_system) + '\n',
repr(pre_user) + '\n',
repr(post_user) + '\n',
repr(pre_assistant) + '\n',
repr(post_assistant) + '\n',
)
result = MASTER_TEMPLATE
if 'system_message' in params:
result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message']))
else:
result = result.replace('<|SYSTEM-MESSAGE|>', '')
result = result.replace('<|PRE-SYSTEM|>', pre_system)
result = result.replace('<|POST-SYSTEM|>', post_system)
result = result.replace('<|PRE-USER|>', pre_user)
result = result.replace('<|POST-USER|>', post_user)
result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
result = result.replace('<|POST-ASSISTANT|>', post_assistant)
result = result.strip()
return result
def my_yaml_output(data):
'''
pyyaml is very inconsistent with multiline strings.
@ -2002,6 +2283,7 @@ def handle_unique_id_select(state):
def handle_start_new_chat_click(state):
import gradio as gr
history = start_new_chat(state)
histories = find_all_histories_with_first_prompts(state)
html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
@ -2016,10 +2298,29 @@ def handle_start_new_chat_click(state):
return [history, html, past_chats_update]
def handle_start_incognito_chat_click(state):
import gradio as gr
unique_id = 'incognito-' + datetime.now().strftime('%Y%m%d-%H-%M-%S')
history = start_new_chat(state, unique_id=unique_id)
html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
convert_to_markdown.cache_clear()
histories = find_all_histories_with_first_prompts(state)
past_chats_update = gr.update(choices=histories, value=unique_id)
return [history, html, past_chats_update]
def handle_delete_chat_confirm_click(state):
filtered_histories = find_all_histories_with_first_prompts(state)
filtered_ids = [h[1] for h in filtered_histories]
index = str(filtered_ids.index(state['unique_id']))
if state['unique_id'] not in filtered_ids:
# Incognito or unknown chat — just load the most recent saved chat
index = '0'
else:
index = str(filtered_ids.index(state['unique_id']))
delete_history(state['unique_id'], state['character_menu'], state['mode'])
history, unique_id = load_history_after_deletion(state, index)
@ -2027,16 +2328,11 @@ def handle_delete_chat_confirm_click(state):
convert_to_markdown.cache_clear()
return [
history,
html,
unique_id,
gr.update(visible=False),
gr.update(visible=True),
]
return [history, html, unique_id]
def handle_branch_chat_click(state):
import gradio as gr
branch_from_index = state['branch_index']
if branch_from_index == -1:
history = state['history']
@ -2048,7 +2344,8 @@ def handle_branch_chat_click(state):
if 'metadata' in history:
history['metadata'] = {k: v for k, v in history['metadata'].items() if int(k.split('_')[-1]) <= branch_from_index}
new_unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
prefix = 'incognito-' if state['unique_id'] and state['unique_id'].startswith('incognito-') else ''
new_unique_id = prefix + datetime.now().strftime('%Y%m%d-%H-%M-%S')
save_history(history, new_unique_id, state['character_menu'], state['mode'])
histories = find_all_histories_with_first_prompts(state)
@ -2086,14 +2383,19 @@ def handle_edit_message_click(state):
original_visible = history['visible'][message_index][role_idx]
original_timestamp = history['metadata'][key].get('timestamp', get_current_timestamp())
history['metadata'][key]["versions"] = [{
version_entry = {
"content": original_content,
"visible_content": original_visible,
"timestamp": original_timestamp
}]
}
ts = history['metadata'][key].get('tool_sequence')
if ts is not None:
version_entry['tool_sequence'] = ts
history['metadata'][key]["versions"] = [version_entry]
history['internal'][message_index][role_idx] = apply_extensions('input', new_text, state, is_chat=True)
history['visible'][message_index][role_idx] = html.escape(new_text)
history['metadata'][key].pop('tool_sequence', None)
add_message_version(history, role, message_index, is_current=True)
@ -2138,6 +2440,14 @@ def handle_navigate_version_click(state):
history['internal'][message_index][msg_content_idx] = version_to_load['content']
history['visible'][message_index][msg_content_idx] = version_to_load['visible_content']
metadata['current_version_index'] = new_idx
# Restore per-version tool_sequence so follow-up prompts see consistent context
version_ts = version_to_load.get('tool_sequence')
if version_ts is not None:
metadata['tool_sequence'] = version_ts
else:
metadata.pop('tool_sequence', None)
update_message_metadata(history['metadata'], role, message_index, timestamp=version_to_load['timestamp'])
# Redraw and save
@ -2148,6 +2458,7 @@ def handle_navigate_version_click(state):
def handle_rename_chat_click():
import gradio as gr
return [
gr.update(value="My New Chat"),
gr.update(visible=True),
@ -2155,6 +2466,14 @@ def handle_rename_chat_click():
def handle_rename_chat_confirm(rename_to, state):
import gradio as gr
if state['unique_id'] and state['unique_id'].startswith('incognito-'):
return [
gr.update(),
gr.update(visible=False),
]
rename_history(state['unique_id'], rename_to, state['character_menu'], state['mode'])
histories = find_all_histories_with_first_prompts(state)
@ -2165,11 +2484,13 @@ def handle_rename_chat_confirm(rename_to, state):
def handle_search_chat_change(state):
import gradio as gr
histories = find_all_histories_with_first_prompts(state)
return gr.update(choices=histories)
def handle_upload_chat_history(load_chat_history, state):
import gradio as gr
history = start_new_chat(state)
history = load_history_json(load_chat_history, history)
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
@ -2192,6 +2513,7 @@ def handle_upload_chat_history(load_chat_history, state):
def handle_character_menu_change(state):
import gradio as gr
name1, name2, picture, greeting, context = load_character(state['character_menu'], state['name1'], state['name2'])
state['name1'] = name1
@ -2244,6 +2566,7 @@ def handle_character_picture_change(picture_path):
def handle_mode_change(state):
import gradio as gr
history, loaded_unique_id = load_latest_history(state)
histories = find_all_histories_with_first_prompts(state)
@ -2270,6 +2593,7 @@ def handle_mode_change(state):
def handle_save_character_click(name2):
import gradio as gr
return [
name2,
gr.update(visible=True)
@ -2277,6 +2601,7 @@ def handle_save_character_click(name2):
def handle_load_template_click(instruction_template):
from modules.models_settings import load_instruction_template
output = load_instruction_template(instruction_template)
return [
output,
@ -2285,6 +2610,7 @@ def handle_load_template_click(instruction_template):
def handle_save_template_click(instruction_template_str):
import gradio as gr
contents = generate_instruction_template_yaml(instruction_template_str)
return [
"My Template.yaml",
@ -2295,6 +2621,7 @@ def handle_save_template_click(instruction_template_str):
def handle_delete_template_click(template):
import gradio as gr
return [
f"{template}.yaml",
str(shared.user_data_dir / 'instruction-templates') + '/',
@ -2310,6 +2637,7 @@ def handle_your_picture_change(picture, state):
def handle_send_instruction_click(state):
import gradio as gr
state['mode'] = 'instruct'
state['history'] = {'internal': [], 'visible': [], 'metadata': {}}
@ -2322,6 +2650,7 @@ def handle_send_instruction_click(state):
def handle_send_chat_click(state):
import gradio as gr
output = generate_chat_prompt("", state, _continue=True)
if state["show_two_notebook_columns"]:

View file

@ -1,3 +1,4 @@
import math
import queue
import threading
import traceback
@ -9,6 +10,7 @@ import torch
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job
from exllamav3.generator.filter import Filter
from exllamav3.generator.sampler import (
CustomSampler,
SS_AdaptiveP,
@ -36,6 +38,29 @@ except Exception:
traceback.print_exc()
class LogitBiasFilter(Filter):
"""Filter subclass that applies a static additive logit bias mask."""
def __init__(self, tokenizer, logit_bias_dict):
super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False)
self.logit_bias_dict = logit_bias_dict
self._mask = None
def reset(self): pass
def accept_token(self, token): pass
def is_completed(self): return False
def use_background_worker(self): return False
def get_next_logit_mask(self):
if self._mask is None:
self._mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
for token_id_str, bias in self.logit_bias_dict.items():
token_id = int(token_id_str)
if 0 <= token_id < self.vocab_size:
self._mask[0, token_id] = bias
return self._mask
class ConcurrentGenerator:
def __init__(self, generator):
self.generator = generator
@ -53,7 +78,16 @@ class ConcurrentGenerator:
if not self.job_queues:
self.has_jobs.clear()
continue
results = self.generator.iterate()
try:
results = self.generator.iterate()
except Exception:
logger.error("Exception in ConcurrentGenerator iterate loop:\n" + traceback.format_exc())
for q in self.job_queues.values():
q.put(None)
self.job_queues.clear()
self.generator.clear_queue()
self.has_jobs.clear()
continue
for result in results:
job = result["job"]
q = self.job_queues.get(job)
@ -89,6 +123,10 @@ class Exllamav3Model:
def __init__(self):
pass
@property
def device(self) -> torch.device:
return torch.device(0)
@classmethod
def from_pretrained(cls, path_to_model):
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
@ -149,8 +187,21 @@ class Exllamav3Model:
load_params['tensor_p'] = True
load_params['tp_backend'] = shared.args.tp_backend
model.load(**load_params)
tokenizer = Tokenizer.from_config(config)
# Load vision and draft before the main model so autosplit
# accounts for their VRAM usage.
# Load vision model component (ExLlamaV3 native)
vision_model = None
if "vision_config" in config.config_dict:
logger.info("Vision component detected in model config. Attempting to load...")
try:
vision_model = Model.from_config(config, component="vision")
vision_model.load(progressbar=True)
logger.info("Vision model loaded successfully.")
except Exception as e:
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
else:
logger.info("No vision component in model config. Skipping multimodal setup.")
# Initialize draft model for speculative decoding
draft_model = None
@ -166,23 +217,8 @@ class Exllamav3Model:
logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
else:
draft_config = Config.from_directory(str(draft_path))
# Set context size for draft model with 256-multiple validation
if shared.args.ctx_size_draft > 0:
draft_max_tokens = shared.args.ctx_size_draft
else:
draft_max_tokens = shared.args.ctx_size
# Validate draft model context size is a multiple of 256
if draft_max_tokens % 256 != 0:
adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
draft_max_tokens = adjusted_draft_tokens
draft_config.max_seq_len = draft_max_tokens
draft_model = Model.from_config(draft_config)
draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
draft_load_params = {'progressbar': True}
if split:
@ -191,18 +227,9 @@ class Exllamav3Model:
draft_model.load(**draft_load_params)
logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
# Load vision model component (ExLlamaV3 native)
vision_model = None
if "vision_config" in config.config_dict:
logger.info("Vision component detected in model config. Attempting to load...")
try:
vision_model = Model.from_config(config, component="vision")
vision_model.load(progressbar=True)
logger.info("Vision model loaded successfully.")
except Exception as e:
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
else:
logger.info("No vision component in model config. Skipping multimodal setup.")
# Load main model last
model.load(**load_params)
tokenizer = Tokenizer.from_config(config)
generator = Generator(
model=model,
@ -385,11 +412,22 @@ class Exllamav3Model:
else:
max_new_tokens = state['max_new_tokens']
# Get stop conditions
# Use full EOS token list from config (may contain multiple IDs)
stop_conditions = []
if not state['ban_eos_token']:
if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
stop_conditions.append(self.tokenizer.eos_token_id)
for eos_id in self.config.eos_token_id_list:
if eos_id is not None:
stop_conditions.append(eos_id)
# Build filters for logit_bias (OpenAI API)
filters = []
logit_bias = state.get('logit_bias')
if logit_bias:
filters.append(LogitBiasFilter(self.tokenizer, logit_bias))
# Logprobs support (OpenAI API)
logprobs = state.get('logprobs', 0) or 0
return_top_tokens = logprobs if logprobs > 0 else 0
seed = state.get('seed', -1)
job = Job(
@ -400,11 +438,15 @@ class Exllamav3Model:
sampler=sampler,
seed=seed if seed >= 0 else None,
stop_conditions=stop_conditions if stop_conditions else None,
filters=filters if filters else None,
return_top_tokens=return_top_tokens,
return_probs=return_top_tokens > 0,
)
# Stream generation
response_text = ""
stop_event = state.get('stop_event')
self.last_completion_probabilities = []
result_queue = self.parallel_generator.submit(job)
try:
@ -416,14 +458,41 @@ class Exllamav3Model:
except queue.Empty:
continue
if result is None or result.get("eos"):
# Capture logprobs from the final eos result too
if result is not None and return_top_tokens > 0:
self._capture_logprobs(result)
break
chunk = result.get("text", "")
# Capture logprobs from streaming results
if return_top_tokens > 0:
self._capture_logprobs(result)
if chunk:
response_text += chunk
yield response_text
finally:
self.parallel_generator.cancel(job)
def _capture_logprobs(self, result):
"""Convert ExLlamav3 top-k token data to the shared logprobs format."""
top_k_tokens = result.get("top_k_tokens")
top_k_probs = result.get("top_k_probs")
if top_k_tokens is None or top_k_probs is None:
return
id_to_piece = self.tokenizer.get_id_to_piece_list(True)
# top_k_tokens shape: (batch, seq_len, k), top_k_probs same
for seq_idx in range(top_k_tokens.shape[1]):
entry = {"top_logprobs": []}
for k_idx in range(top_k_tokens.shape[2]):
token_id = top_k_tokens[0, seq_idx, k_idx].item()
prob = top_k_probs[0, seq_idx, k_idx].item()
token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
logprob = math.log(prob) if prob > 0 else float("-inf")
entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
self.last_completion_probabilities.append(entry)
def generate(self, prompt, state):
output = ""
for chunk in self.generate_with_streaming(prompt, state):
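
The `_capture_logprobs` helper added above turns top-k probabilities into log-probabilities and guards against zero probabilities and out-of-range token ids. A minimal sketch of that conversion for a single position, using plain lists instead of tensors; the function name `to_top_logprobs` and the sample vocabulary are hypothetical and not part of the commit:

```python
import math


def to_top_logprobs(token_ids, probs, id_to_piece):
    """Convert one position's top-k (token id, probability) pairs to logprob entries."""
    entry = {"top_logprobs": []}
    for token_id, prob in zip(token_ids, probs):
        # Fall back to a placeholder when the id is outside the piece table
        token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
        # log(0) is undefined, so a zero probability maps to negative infinity
        logprob = math.log(prob) if prob > 0 else float("-inf")
        entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
    return entry


vocab = ["<s>", "Hello", " world"]  # hypothetical piece table
entry = to_top_logprobs([1, 2, 7], [0.5, 0.0, 0.25], vocab)
```

In the commit itself the same logic runs over tensors of shape `(batch, seq_len, k)` and appends one entry per sequence position.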


@ -201,19 +201,23 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
}
).to(input_ids.device).float()
else:
# When processing with labels, handle as a complete sequence
# Process in chunks if the number of tokens is large
# Labels path: use cache for cross-chunk attention.
tokens_to_process = seq_tensor
all_logits = None
current_len = 0
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
chunk = tokens_to_process[i:i + max_chunk_size]
chunk_logits = self.ex_model.forward(
input_ids=chunk.view(1, -1),
params={
"attn_mode": "flash_attn_nc",
"attn_mode": "flash_attn",
"cache": ex_cache,
"past_len": current_len,
"batch_shape": (1, self.max_tokens),
}
).float()
current_len += chunk.shape[0]
if all_logits is None:
all_logits = chunk_logits
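
The hunk above switches the labels path from `flash_attn_nc` to a cached `flash_attn` forward pass and tracks `current_len`, presumably so that each chunk can attend to the chunks processed before it. The bookkeeping can be sketched without the model; `chunk_offsets` is a hypothetical helper, not a function from the repository:

```python
def chunk_offsets(num_tokens, max_chunk_size):
    """Yield (start, end, past_len) for each chunk of a sequence.

    past_len is the number of tokens already processed, which is the value
    the diff passes as "past_len" for each forward call.
    """
    current_len = 0
    for start in range(0, num_tokens, max_chunk_size):
        end = min(start + max_chunk_size, num_tokens)
        yield start, end, current_len
        current_len += end - start


offsets = list(chunk_offsets(5, 2))
# offsets == [(0, 2, 0), (2, 4, 2), (4, 5, 4)]
```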


@ -6,8 +6,6 @@ from functools import partial
from inspect import signature
from pathlib import Path
import gradio as gr
import modules.shared as shared
from modules.logging_colors import logger
@ -214,6 +212,7 @@ def _apply_custom_js():
def create_extensions_block():
import gradio as gr
to_display = []
for extension, name in iterator():
if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
@ -228,6 +227,7 @@ def create_extensions_block():
def create_extensions_tabs():
import gradio as gr
for extension, name in iterator():
if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
display_name = getattr(extension, 'params', {}).get('display_name', name)


@ -10,6 +10,7 @@ import markdown
from PIL import Image, ImageOps
from modules import shared
from modules.reasoning import extract_reasoning
from modules.sane_markdown_lists import SaneListExtension
from modules.utils import get_available_chat_styles
@ -108,69 +109,41 @@ def replace_blockquote(m):
return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '')
# Thinking block format definitions: (start_tag, end_tag, content_start_tag)
# Use None for start_tag to match from beginning (end-only formats should be listed last)
THINKING_FORMATS = [
('<think>', '</think>', None),
('<|channel|>analysis<|message|>', '<|end|>', '<|start|>assistant<|channel|>final<|message|>'),
('<seed:think>', '</seed:think>', None),
('<|think|>', '<|end|>', '<|content|>'), # Solar Open
('Thinking Process:', '</think>', None), # Qwen3.5 verbose thinking outside tags
(None, '</think>', None), # End-only variant (e.g., Qwen3-next)
]
def extract_thinking_block(string):
"""Extract thinking blocks from the beginning of a string."""
if not string:
return None, string
for start_tag, end_tag, content_tag in THINKING_FORMATS:
end_esc = html.escape(end_tag)
content_esc = html.escape(content_tag) if content_tag else None
if start_tag is None:
# End-only format: require end tag, start from beginning
end_pos = string.find(end_esc)
if end_pos == -1:
continue
thought_start = 0
else:
# Normal format: require start tag
start_esc = html.escape(start_tag)
start_pos = string.find(start_esc)
if start_pos == -1:
continue
thought_start = start_pos + len(start_esc)
end_pos = string.find(end_esc, thought_start)
if end_pos == -1:
# End tag missing - check if content tag can serve as fallback
if content_esc:
content_pos = string.find(content_esc, thought_start)
if content_pos != -1:
thought_end = content_pos
content_start = content_pos + len(content_esc)
else:
thought_end = len(string)
content_start = len(string)
else:
thought_end = len(string)
content_start = len(string)
else:
thought_end = end_pos
if content_esc:
content_pos = string.find(content_esc, end_pos)
content_start = content_pos + len(content_esc) if content_pos != -1 else end_pos + len(end_esc)
else:
content_start = end_pos + len(end_esc)
return string[thought_start:thought_end], string[content_start:]
return None, string
"""Extract thinking blocks from the beginning of an HTML-escaped string."""
return extract_reasoning(string, html_escaped=True)
def build_thinking_block(thinking_content, message_id, has_remaining_content):
def build_tool_call_block(header, body, message_id, index):
"""Build HTML for a tool call accordion block."""
block_id = f"tool-call-{message_id}-{index}"
if body == '...':
# Pending placeholder — no expandable body, just title with ellipsis
return f'''
<details class="thinking-block" data-block-id="{block_id}">
<summary class="thinking-header">
{tool_svg_small}
<span class="thinking-title">{html.escape(header)} ...</span>
</summary>
</details>
'''
# Build a plain <pre> directly to avoid highlight.js auto-detection
escaped_body = html.escape(body)
return f'''
<details class="thinking-block" data-block-id="{block_id}">
<summary class="thinking-header">
{tool_svg_small}
<span class="thinking-title">{html.escape(header)}</span>
</summary>
<div class="thinking-content pretty_scrollbar"><pre><code class="nohighlight">{escaped_body}</code></pre></div>
</details>
'''
def build_thinking_block(thinking_content, message_id, has_remaining_content, thinking_index=0):
"""Build HTML for a thinking block."""
if thinking_content is None:
return None
@ -179,7 +152,7 @@ def build_thinking_block(thinking_content, message_id, has_remaining_content):
thinking_html = process_markdown_content(thinking_content)
# Generate unique ID for the thinking block
block_id = f"thinking-{message_id}-0"
block_id = f"thinking-{message_id}-{thinking_index}"
# Check if thinking is complete or still in progress
is_streaming = not has_remaining_content
@ -344,6 +317,9 @@ def process_markdown_content(string):
# Unescape backslashes
html_output = html_output.replace('\\\\', '\\')
# Wrap tables in a scrollable div
html_output = html_output.replace('<table>', '<div class="table-wrapper pretty_scrollbar"><table>').replace('</table>', '</table></div>')
return html_output
@ -360,24 +336,66 @@ def convert_to_markdown(string, message_id=None):
if message_id is None:
message_id = "unknown"
# Extract different components from the string
thinking_content, remaining_content = extract_thinking_block(string)
# Find tool call blocks by position, then process the text segments
# between them using extract_thinking_block (which supports all
# THINKING_FORMATS, including end-only variants like Qwen's).
tool_call_pattern = re.compile(r'<tool_call>(.*?)\n(.*?)\n</tool_call>', re.DOTALL)
tool_calls = list(tool_call_pattern.finditer(string))
# Build individual HTML blocks
blocks = []
if not tool_calls:
# No tool calls — use original single-pass extraction
thinking_content, remaining_content = extract_thinking_block(string)
blocks = []
thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content))
if thinking_html:
blocks.append(thinking_html)
# Add thinking block if present
thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content))
if thinking_html:
blocks.append(thinking_html)
main_html = build_main_content_block(remaining_content)
if main_html:
blocks.append(main_html)
# Add main content block
main_html = build_main_content_block(remaining_content)
if main_html:
blocks.append(main_html)
return ''.join(blocks)
# Assemble all blocks into final HTML
return ''.join(blocks)
# Split string into text segments around tool_call blocks and
# run extract_thinking_block on each segment for full format support.
html_parts = []
last_end = 0
tool_idx = 0
think_idx = 0
def process_text_segment(text, is_last_segment):
"""Process a text segment between tool_call blocks for thinking content."""
nonlocal think_idx
if not text.strip():
return
while text.strip():
thinking_content, remaining = extract_thinking_block(text)
if thinking_content is None:
break
has_remaining = bool(remaining.strip()) or not is_last_segment
html_parts.append(build_thinking_block(thinking_content, message_id, has_remaining, think_idx))
think_idx += 1
text = remaining
if text.strip():
html_parts.append(process_markdown_content(text))
for tc in tool_calls:
# Process text before this tool_call
process_text_segment(string[last_end:tc.start()], is_last_segment=False)
# Add tool call accordion
header = tc.group(1).strip()
body = tc.group(2).strip()
html_parts.append(build_tool_call_block(header, body, message_id, tool_idx))
tool_idx += 1
last_end = tc.end()
# Process text after the last tool_call
process_text_segment(string[last_end:], is_last_segment=True)
return ''.join(html_parts)
def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):
@ -435,6 +453,7 @@ branch_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24
edit_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="tabler-icon tabler-icon-pencil"><path d="M4 20h4l10.5 -10.5a2.828 2.828 0 1 0 -4 -4l-10.5 10.5v4"></path><path d="M13.5 6.5l4 4"></path></svg>'''
info_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
info_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
tool_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-tool"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M7 10h3v-3l-3.5 -3.5a6 6 0 0 1 8 8l6 6a2 2 0 0 1 -3 3l-6 -6a6 6 0 0 1 -8 -8l3.5 3.5" /></svg>'''
attachment_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.48-8.48l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48"></path></svg>'''
copy_button = f'<button class="footer-button footer-copy-button" title="Copy" onclick="copyToClipboard(this)">{copy_svg}</button>'
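
The new `convert_to_markdown` locates `<tool_call>` blocks with a regular expression and processes the text between them separately. A minimal sketch of that splitting step, reusing the pattern from the diff; `split_segments` and the sample message are hypothetical, and the thinking-block handling and HTML rendering are left out:

```python
import re

# Same pattern as in the diff: header on the first line, body on the following lines
TOOL_CALL_PATTERN = re.compile(r'<tool_call>(.*?)\n(.*?)\n</tool_call>', re.DOTALL)


def split_segments(string):
    """Return ('text', str) and ('tool', header, body) items in document order."""
    parts = []
    last_end = 0
    for tc in TOOL_CALL_PATTERN.finditer(string):
        text = string[last_end:tc.start()]
        if text.strip():
            parts.append(("text", text))
        parts.append(("tool", tc.group(1).strip(), tc.group(2).strip()))
        last_end = tc.end()
    tail = string[last_end:]
    if tail.strip():
        parts.append(("text", tail))
    return parts


sample = 'Let me check.\n<tool_call>get_weather\n{"city": "Paris"}\n</tool_call>\nIt is sunny.'
parts = split_segments(sample)
```

Whitespace-only segments are skipped, mirroring the `text.strip()` checks in `process_text_segment`.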


@ -36,6 +36,7 @@ class LlamaServer:
self.process = None
self.session = requests.Session()
self.vocabulary_size = None
self.n_ctx = None
self.bos_token = "<s>"
self.last_prompt_token_count = 0
@ -133,9 +134,20 @@ class LlamaServer:
payload["samplers"] = filtered_samplers
logit_bias = []
if state['custom_token_bans']:
to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
payload["logit_bias"] = to_ban
logit_bias.extend([[int(token_id.strip()), False] for token_id in state['custom_token_bans'].split(',') if token_id.strip()])
if state.get('logit_bias'):
for token_id_str, bias in state['logit_bias'].items():
logit_bias.append([int(token_id_str), bias])
if logit_bias:
payload["logit_bias"] = logit_bias
n_probs = state.get('logprobs', 0)
if n_probs and n_probs > 0:
payload["n_probs"] = n_probs
return payload
@ -215,6 +227,7 @@ class LlamaServer:
response.raise_for_status() # Raise an exception for HTTP errors
full_text = ""
self.last_completion_probabilities = []
# Process the streaming response
stop_event = state.get('stop_event')
@ -240,6 +253,10 @@ class LlamaServer:
full_text += data['content']
yield full_text
# Capture logprobs if present
if 'completion_probabilities' in data:
self.last_completion_probabilities.extend(data['completion_probabilities'])
# Check if generation is complete
if data.get('stop', False):
break
@ -304,12 +321,17 @@ class LlamaServer:
self.vocabulary_size = model_info["meta"]["n_vocab"]
def _get_bos_token(self):
"""Get and store the model's BOS token."""
"""Get and store the model's BOS token and context size."""
url = f"http://127.0.0.1:{self.port}/props"
response = self.session.get(url).json()
if "bos_token" in response:
self.bos_token = response["bos_token"]
# Get actual n_ctx from the server (important when --fit auto-selects it)
n_ctx = response.get("default_generation_settings", {}).get("n_ctx")
if n_ctx:
self.n_ctx = n_ctx
def _is_port_available(self, port):
"""Check if a port is available for use."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@ -349,11 +371,14 @@ class LlamaServer:
if shared.args.ctx_size > 0:
cmd += ["--ctx-size", str(shared.args.ctx_size)]
elif shared.args.gpu_layers >= 0:
cmd += ["--ctx-size", "8192"]
if shared.args.gpu_layers >= 0:
cmd += ["--gpu-layers", str(shared.args.gpu_layers), "--fit", "off"]
else:
cmd += ["--fit", "on"]
cmd += ["--fit-ctx", "8192"]
if shared.args.fit_target:
cmd += ["--fit-target", shared.args.fit_target]
@ -379,10 +404,6 @@ class LlamaServer:
if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
cache_type = shared.args.cache_type
if shared.args.compress_pos_emb != 1:
cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
if shared.args.rope_freq_base > 0:
cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
if shared.args.mmproj not in [None, 'None']:
path = Path(shared.args.mmproj)
if not path.exists():
@ -455,7 +476,7 @@ class LlamaServer:
print()
gpu_layers_str = "auto" if shared.args.gpu_layers < 0 else str(shared.args.gpu_layers)
ctx_size_str = "auto" if shared.args.ctx_size == 0 else str(shared.args.ctx_size)
ctx_size_str = "auto" if shared.args.ctx_size == 0 and shared.args.gpu_layers < 0 else str(shared.args.ctx_size or 8192)
logger.info(f"Using gpu_layers={gpu_layers_str} | ctx_size={ctx_size_str} | cache_type={cache_type}")
# Start the server with pipes for output
self.process = subprocess.Popen(
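
The payload hunk in this file merges two sources into one `logit_bias` list: banned token ids from `custom_token_bans` (each paired with `False`) and the numeric biases from the OpenAI-style `logit_bias` dict. A standalone sketch of that merge; the function name `build_logit_bias` is hypothetical, and the inputs are made-up values:

```python
def build_logit_bias(custom_token_bans, logit_bias=None):
    """Merge comma-separated banned ids and a {token_id: bias} dict into one list."""
    entries = []
    if custom_token_bans:
        # Strip whitespace and skip empty items so "12, 34," parses cleanly
        entries.extend(
            [int(token_id.strip()), False]
            for token_id in custom_token_bans.split(',')
            if token_id.strip()
        )
    if logit_bias:
        for token_id_str, bias in logit_bias.items():
            entries.append([int(token_id_str), bias])
    return entries


merged = build_logit_bias("12, 34,", {"56": -2.5})
# merged == [[12, False], [34, False], [56, -2.5]]
```

The pre-change code called `int()` on every comma-separated item, so an input with a trailing comma would have raised `ValueError`; the empty-item filter avoids that.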


@ -1,8 +1,6 @@
import functools
from collections import OrderedDict
import gradio as gr
loaders_and_params = OrderedDict({
'llama.cpp': [
'gpu_layers',
@ -17,8 +15,6 @@ loaders_and_params = OrderedDict({
'tensor_split',
'extra_flags',
'streaming_llm',
'rope_freq_base',
'compress_pos_emb',
'row_split',
'no_kv_offload',
'no_mmap',
@ -43,8 +39,6 @@ loaders_and_params = OrderedDict({
'Transformers': [
'gpu_split',
'cpu_memory',
'alpha_value',
'compress_pos_emb',
'compute_dtype',
'quant_type',
'load_in_8bit',
@ -71,7 +65,6 @@ loaders_and_params = OrderedDict({
'gpu_split',
'model_draft',
'draft_max',
'ctx_size_draft',
'speculative_decoding_accordion',
'enable_tp',
'tp_backend',
@ -208,6 +201,7 @@ loaders_samplers = {
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'reasoning_effort',
'seed',
'skip_special_tokens',
},
@ -244,6 +238,7 @@ loaders_samplers = {
'reasoning_effort',
'seed',
'sampler_priority',
'custom_token_bans',
'dry_sequence_breakers',
'grammar_string',
'grammar_file_row',
@ -277,6 +272,7 @@ def list_all_samplers():
def blacklist_samplers(loader, dynamic_temperature):
import gradio as gr
all_samplers = list_all_samplers()
output = []
@ -302,7 +298,58 @@ def get_all_params():
return sorted(all_params)
def list_model_elements():
return [
'filter_by_loader',
'loader',
'cpu_memory',
'gpu_layers',
'fit_target',
'cpu_moe',
'threads',
'threads_batch',
'batch_size',
'ubatch_size',
'ctx_size',
'cache_type',
'tensor_split',
'extra_flags',
'streaming_llm',
'gpu_split',
'compute_dtype',
'quant_type',
'load_in_8bit',
'load_in_4bit',
'attn_implementation',
'cpu',
'disk',
'row_split',
'no_kv_offload',
'no_mmap',
'mlock',
'numa',
'parallel',
'use_double_quant',
'bf16',
'enable_tp',
'tp_backend',
'cfg_cache',
'no_use_fast',
'model_draft',
'draft_max',
'gpu_layers_draft',
'device_draft',
'ctx_size_draft',
'spec_type',
'spec_ngram_size_n',
'spec_ngram_size_m',
'spec_ngram_min_hits',
'mmproj',
]
def make_loader_params_visible(loader):
import gradio as gr
params = []
all_params = get_all_params()
if loader in loaders_and_params:


@ -38,6 +38,9 @@ def load_model(model_name, loader=None):
sampler_hijack.hijack_samplers()
shared.args.loader = loader
if loader != 'llama.cpp' and shared.args.ctx_size == 0:
shared.args.ctx_size = 8192
output = load_func_map[loader](model_name)
if type(output) is tuple:
model, tokenizer = output
@ -54,6 +57,8 @@ def load_model(model_name, loader=None):
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
if shared.args.ctx_size > 0:
shared.settings['truncation_length'] = shared.args.ctx_size
elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx:
shared.settings['truncation_length'] = model.n_ctx
shared.is_multimodal = False
if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):
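
Taken together, the two hunks in this file decide which context length ends up as the truncation length when `ctx_size` is 0. A simplified sketch of that decision, assuming a loader whose truncation length follows the context size (the commit only applies the second step to ExLlama, TensorRT, and llama.cpp loaders); `resolve_truncation_length` is a hypothetical name:

```python
def resolve_truncation_length(loader, ctx_size, server_n_ctx=None, default=8192):
    """Return the truncation length to use, or None to keep the current setting."""
    # Loaders other than llama.cpp cannot auto-select a context size,
    # so 0 falls back to a fixed default
    if loader != "llama.cpp" and ctx_size == 0:
        ctx_size = default
    if ctx_size > 0:
        return ctx_size
    # llama.cpp with ctx_size == 0: use the n_ctx reported by the server, if any
    return server_n_ctx or None


a = resolve_truncation_length("ExLlamav3", 0)
b = resolve_truncation_length("llama.cpp", 0, server_n_ctx=32768)
c = resolve_truncation_length("llama.cpp", 4096, server_n_ctx=32768)
```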


@ -4,10 +4,9 @@ import re
from math import floor
from pathlib import Path
import gradio as gr
import yaml
from modules import chat, loaders, metadata_gguf, shared, ui
from modules import loaders, metadata_gguf, shared
from modules.logging_colors import logger
from modules.utils import resolve_model_path
@ -16,9 +15,6 @@ def get_fallback_settings():
return {
'bf16': False,
'ctx_size': 8192,
'rope_freq_base': 0,
'compress_pos_emb': 1,
'alpha_value': 1,
'truncation_length': shared.settings['truncation_length'],
'truncation_length_info': shared.settings['truncation_length'],
'skip_special_tokens': shared.settings['skip_special_tokens'],
@ -68,14 +64,8 @@ def get_model_metadata(model):
for k in metadata:
if k.endswith('.context_length'):
model_settings['ctx_size'] = min(metadata[k], 8192)
model_settings['ctx_size'] = 0
model_settings['truncation_length_info'] = metadata[k]
elif k.endswith('rope.freq_base'):
model_settings['rope_freq_base'] = metadata[k]
elif k.endswith('rope.scale_linear'):
model_settings['compress_pos_emb'] = metadata[k]
elif k.endswith('rope.scaling.factor'):
model_settings['compress_pos_emb'] = metadata[k]
elif k.endswith('.block_count'):
model_settings['gpu_layers'] = -1
model_settings['max_gpu_layers'] = metadata[k] + 1
@ -120,15 +110,6 @@ def get_model_metadata(model):
model_settings['ctx_size'] = min(value, 8192)
break
if 'rope_theta' in metadata:
model_settings['rope_freq_base'] = metadata['rope_theta']
elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']:
model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta']
if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
if metadata['rope_scaling']['type'] == 'linear':
model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
model_settings['bf16'] = True
@ -182,10 +163,6 @@ def get_model_metadata(model):
if 'instruction_template' not in model_settings:
model_settings['instruction_template'] = 'Alpaca'
# Ignore rope_freq_base if set to the default value
if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000:
model_settings.pop('rope_freq_base')
# Apply user settings from user_data/models/config-user.yaml
settings = shared.user_config
for pat in settings:
@ -199,7 +176,7 @@ def get_model_metadata(model):
# Load instruction template if defined by name rather than by value
if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
model_settings['instruction_template_str'] = load_instruction_template(model_settings['instruction_template'])
return model_settings
@ -228,7 +205,7 @@ def update_model_parameters(state, initial=False):
'''
UI: update the command-line arguments based on the interface values
'''
elements = ui.list_model_elements() # the names of the parameters
elements = loaders.list_model_elements() # the names of the parameters
for i, element in enumerate(elements):
if element not in state:
@ -248,6 +225,7 @@ def apply_model_settings_to_state(model, state):
'''
UI: update the state variable with the model settings
'''
import gradio as gr
model_settings = get_model_metadata(model)
if 'loader' in model_settings:
loader = model_settings.pop('loader')
@ -290,7 +268,7 @@ def save_model_settings(model, state):
if model_regex not in user_config:
user_config[model_regex] = {}
for k in ui.list_model_elements():
for k in loaders.list_model_elements():
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
user_config[model_regex][k] = state[k]
@ -419,3 +397,103 @@ def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type):
vram_usage = estimate_vram(model, gpu_layers, ctx_size, cache_type)
return f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">{vram_usage:.0f} MiB</span></div>"
def load_instruction_template(template):
if template == 'None':
return ''
for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']:
if filepath.exists():
break
else:
return ''
with open(filepath, 'r', encoding='utf-8') as f:
file_contents = f.read()
data = yaml.safe_load(file_contents)
if 'instruction_template' in data:
return data['instruction_template']
else:
return _jinja_template_from_old_format(data)
def _jinja_template_from_old_format(params, verbose=False):
MASTER_TEMPLATE = """
{%- set ns = namespace(found=false) -%}
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
{%- set ns.found = true -%}
{%- endif -%}
{%- endfor -%}
{%- if not ns.found -%}
{{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
{%- endif %}
{%- for message in messages %}
{%- if message['role'] == 'system' -%}
{{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
{%- else -%}
{%- if message['role'] == 'user' -%}
{{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
{%- else -%}
{{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{-'<|PRE-ASSISTANT-GENERATE|>'-}}
{%- endif -%}
"""
if 'context' in params and '<|system-message|>' in params['context']:
pre_system = params['context'].split('<|system-message|>')[0]
post_system = params['context'].split('<|system-message|>')[1]
else:
pre_system = ''
post_system = ''
pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]
pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
post_assistant = params['turn_template'].split('<|bot-message|>')[1]
def preprocess(string):
return string.replace('\n', '\\n').replace('\'', '\\\'')
pre_system = preprocess(pre_system)
post_system = preprocess(post_system)
pre_user = preprocess(pre_user)
post_user = preprocess(post_user)
pre_assistant = preprocess(pre_assistant)
post_assistant = preprocess(post_assistant)
if verbose:
print(
'\n',
repr(pre_system) + '\n',
repr(post_system) + '\n',
repr(pre_user) + '\n',
repr(post_user) + '\n',
repr(pre_assistant) + '\n',
repr(post_assistant) + '\n',
)
result = MASTER_TEMPLATE
if 'system_message' in params:
result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message']))
else:
result = result.replace('<|SYSTEM-MESSAGE|>', '')
result = result.replace('<|PRE-SYSTEM|>', pre_system)
result = result.replace('<|POST-SYSTEM|>', post_system)
result = result.replace('<|PRE-USER|>', pre_user)
result = result.replace('<|POST-USER|>', post_user)
result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
result = result.replace('<|POST-ASSISTANT|>', post_assistant)
result = result.strip()
return result
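
The `_jinja_template_from_old_format` function moved into this file derives the per-role prefixes and suffixes by splitting the legacy `turn_template` on its placeholders. A sketch of just that splitting step, with the same expressions as the diff; the wrapper `split_turn_template` and the Alpaca-like sample values are hypothetical:

```python
def split_turn_template(turn_template, user, bot):
    """Split a legacy turn template into (pre_user, post_user, pre_assistant, post_assistant)."""
    pre_user = turn_template.split('<|user-message|>')[0].replace('<|user|>', user)
    post_user = turn_template.split('<|user-message|>')[1].split('<|bot|>')[0]
    # Re-attach the '<|bot|>' marker so it can be replaced by the bot prefix
    pre_assistant = '<|bot|>' + turn_template.split('<|bot-message|>')[0].split('<|bot|>')[1]
    pre_assistant = pre_assistant.replace('<|bot|>', bot)
    post_assistant = turn_template.split('<|bot-message|>')[1]
    return pre_user, post_user, pre_assistant, post_assistant


# Hypothetical legacy template in the old YAML format
template = "<|user|>\n<|user-message|>\n\n<|bot|>\n<|bot-message|>\n\n"
pieces = split_turn_template(template, "### Instruction:", "### Response:")
# pieces == ("### Instruction:\n", "\n\n", "### Response:\n", "\n\n")
```

In the full function these four strings are then escaped and substituted into `MASTER_TEMPLATE`.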


@ -16,9 +16,10 @@ default_preset_values = {
'dynatemp_exponent': 1,
'smoothing_factor': 0,
'smoothing_curve': 1,
'min_p': 0,
'top_p': 1,
'top_k': 0,
'min_p': 0,
'top_n_sigma': 0,
'typical_p': 1,
'xtc_threshold': 0.1,
'xtc_probability': 0,
@ -26,7 +27,6 @@ default_preset_values = {
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'top_n_sigma': 0,
'adaptive_target': 0,
'adaptive_decay': 0.9,
'dry_multiplier': 0,

modules/reasoning.py Normal file

@ -0,0 +1,94 @@
import html as html_module
# Thinking block format definitions: (start_tag, end_tag, content_start_tag)
# Use None for start_tag to match from beginning (end-only formats should be listed last)
THINKING_FORMATS = [
('<think>', '</think>', None),
('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'),
('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'),
('<seed:think>', '</seed:think>', None),
('<|think|>', '<|end|>', '<|content|>'), # Solar Open
# ('Thinking Process:', '</think>', None), # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming
(None, '</think>', None), # End-only variant (e.g., Qwen3-next)
]
def extract_reasoning(text, html_escaped=False):
    """Extract reasoning/thinking blocks from the beginning of a string.

    When html_escaped=True, tags are HTML-escaped before searching
    (for use on already-escaped UI strings).

    Returns (reasoning_content, final_content) where reasoning_content is
    None if no thinking block is found.
    """
    if not text:
        return None, text

    esc = html_module.escape if html_escaped else lambda s: s

    for start_tag, end_tag, content_tag in THINKING_FORMATS:
        end_esc = esc(end_tag)
        content_esc = esc(content_tag) if content_tag else None

        if start_tag is None:
            # End-only format: require end tag, start from beginning
            end_pos = text.find(end_esc)
            if end_pos == -1:
                continue
            thought_start = 0
        else:
            # Normal format: require start tag
            start_esc = esc(start_tag)
            start_pos = text.find(start_esc)
            if start_pos == -1:
                # During streaming, the start tag may be arriving partially.
                # If the text is a prefix of a start tag, return empty content
                # to prevent the partial tag from leaking.
                stripped = text.strip()
                if stripped and start_esc.startswith(stripped):
                    return '', ''
                continue
            thought_start = start_pos + len(start_esc)
            end_pos = text.find(end_esc, thought_start)

        if end_pos == -1:
            # End tag missing - check if content tag can serve as fallback
            if content_esc:
                content_pos = text.find(content_esc, thought_start)
                if content_pos != -1:
                    thought_end = content_pos
                    content_start = content_pos + len(content_esc)
                else:
                    thought_end = len(text)
                    content_start = len(text)
            else:
                thought_end = len(text)
                content_start = len(text)
        else:
            thought_end = end_pos
            if content_esc:
                content_pos = text.find(content_esc, end_pos)
                if content_pos != -1:
                    content_start = content_pos + len(content_esc)
                else:
                    # Content tag expected but not yet present (e.g. partial
                    # streaming); suppress intermediate tags between end_tag
                    # and content_tag so they don't leak as content.
                    content_start = len(text)
            else:
                content_start = end_pos + len(end_esc)

        return text[thought_start:thought_end], text[content_start:]

    # Handle standalone GPT-OSS final channel marker without a preceding
    # analysis/commentary block (the model skipped thinking entirely).
    for marker in ['<|start|>assistant<|channel|>final<|message|>', '<|channel|>final<|message|>']:
        marker_esc = esc(marker)
        pos = text.find(marker_esc)
        if pos != -1:
            before = text[:pos].strip()
            after = text[pos + len(marker_esc):]
            return (before if before else None), after

    return None, text
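
To make the return convention concrete, here is a simplified sketch that handles a single start/end tag pair only. `split_think_block` is a hypothetical helper written for this illustration, not the module's function; it leaves out the content-tag fallback, the end-only variant, HTML escaping, and the partial-tag suppression.

```python
def split_think_block(text, start_tag='<think>', end_tag='</think>'):
    """Simplified illustration: return (reasoning, final) for one tag pair."""
    start = text.find(start_tag)
    if start == -1:
        # No thinking block: reasoning is None, text is returned unchanged.
        return None, text
    thought_start = start + len(start_tag)
    end = text.find(end_tag, thought_start)
    if end == -1:
        # End tag not seen yet (e.g. mid-stream): everything so far is reasoning.
        return text[thought_start:], ''
    return text[thought_start:end], text[end + len(end_tag):]


complete = split_think_block('<think>2+2=4</think>The answer is 4.')
partial = split_think_block('<think>still thinking')
plain = split_think_block('No reasoning here.')
print(complete, partial, plain)
```

The `partial` case mirrors the branch above where the end tag is missing and no content tag is defined: the whole remainder counts as reasoning and the final content is empty.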


@ -47,7 +47,7 @@ parser = argparse.ArgumentParser(description="Text Generation Web UI", conflict_
# Basic settings
group = parser.add_argument_group('Basic settings')
group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user data directory. Default: auto-detected.')
group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.')
group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.')
group.add_argument('--model', type=str, help='Name of the model to load by default.')
group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.')
@ -76,7 +76,7 @@ group.add_argument('--loader', type=str, help='Choose the model loader manually,
# Cache
group = parser.add_argument_group('Context and cache')
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, metavar='N', help='Context size in tokens. llama.cpp: 0 = auto if gpu-layers is also -1.')
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=0, metavar='N', help='Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.')
group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
# Speculative decoding
@ -108,7 +108,7 @@ group.add_argument('--threads', type=int, default=0, help='Number of threads to
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
group.add_argument('--fit-target', type=str, default='1024', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices. Default: 1024.')
group.add_argument('--fit-target', type=str, default='512', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.')
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
# Transformers/Accelerate
@ -139,12 +139,6 @@ group.add_argument('--enable-tp', '--enable_tp', action='store_true', help='Enab
group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.')
group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
# RoPE
group = parser.add_argument_group('RoPE')
group.add_argument('--alpha_value', type=float, default=1, help='Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.')
group.add_argument('--rope_freq_base', type=int, default=0, help='If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).')
group.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.")
# Gradio
group = parser.add_argument_group('Gradio')
group.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
@ -163,7 +157,7 @@ group.add_argument('--portable', action='store_true', help='Hide features not av
# API
group = parser.add_argument_group('API')
group.add_argument('--api', action='store_true', help='Enable the API extension.')
group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')
group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
group.add_argument('--api-key', type=str, default='', help='API authentication key.')
@ -181,9 +175,10 @@ group.add_argument('--dynatemp-high', type=float, default=_d['dynatemp_high'], m
group.add_argument('--dynatemp-exponent', type=float, default=_d['dynatemp_exponent'], metavar='N', help='Dynamic temperature exponent')
group.add_argument('--smoothing-factor', type=float, default=_d['smoothing_factor'], metavar='N', help='Smoothing factor')
group.add_argument('--smoothing-curve', type=float, default=_d['smoothing_curve'], metavar='N', help='Smoothing curve')
group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P')
group.add_argument('--top-p', type=float, default=_d['top_p'], metavar='N', help='Top P')
group.add_argument('--top-k', type=int, default=_d['top_k'], metavar='N', help='Top K')
group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P')
group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma')
group.add_argument('--typical-p', type=float, default=_d['typical_p'], metavar='N', help='Typical P')
group.add_argument('--xtc-threshold', type=float, default=_d['xtc_threshold'], metavar='N', help='XTC threshold')
group.add_argument('--xtc-probability', type=float, default=_d['xtc_probability'], metavar='N', help='XTC probability')
@ -191,7 +186,6 @@ group.add_argument('--epsilon-cutoff', type=float, default=_d['epsilon_cutoff'],
group.add_argument('--eta-cutoff', type=float, default=_d['eta_cutoff'], metavar='N', help='Eta cutoff')
group.add_argument('--tfs', type=float, default=_d['tfs'], metavar='N', help='TFS')
group.add_argument('--top-a', type=float, default=_d['top_a'], metavar='N', help='Top A')
group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma')
group.add_argument('--adaptive-target', type=float, default=_d['adaptive_target'], metavar='N', help='Adaptive target')
group.add_argument('--adaptive-decay', type=float, default=_d['adaptive_decay'], metavar='N', help='Adaptive decay')
group.add_argument('--dry-multiplier', type=float, default=_d['dry_multiplier'], metavar='N', help='DRY multiplier')
@ -263,8 +257,9 @@ settings = {
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>',
'enable_web_search': False,
'web_search_pages': 3,
'selected_tools': [],
'prompt-notebook': '',
'preset': 'Qwen3 - Thinking' if (user_data_dir / 'presets/Qwen3 - Thinking.yaml').exists() else None,
'preset': 'Top-P' if (user_data_dir / 'presets/Top-P.yaml').exists() else None,
'max_new_tokens': 512,
'max_new_tokens_min': 1,
'max_new_tokens_max': 4096,
@ -289,7 +284,7 @@ settings = {
'include_past_attachments': True,
# Generation parameters - Curve shape
'temperature': 0.6,
'temperature': neutral_samplers['temperature'],
'dynatemp_low': neutral_samplers['dynatemp_low'],
'dynatemp_high': neutral_samplers['dynatemp_high'],
'dynatemp_exponent': neutral_samplers['dynatemp_exponent'],
@ -297,9 +292,10 @@ settings = {
'smoothing_curve': neutral_samplers['smoothing_curve'],
# Generation parameters - Curve cutoff
'min_p': neutral_samplers['min_p'],
'top_p': 0.95,
'top_k': 20,
'top_k': neutral_samplers['top_k'],
'min_p': neutral_samplers['min_p'],
'top_n_sigma': neutral_samplers['top_n_sigma'],
'typical_p': neutral_samplers['typical_p'],
'xtc_threshold': neutral_samplers['xtc_threshold'],
'xtc_probability': neutral_samplers['xtc_probability'],
@ -307,7 +303,6 @@ settings = {
'eta_cutoff': neutral_samplers['eta_cutoff'],
'tfs': neutral_samplers['tfs'],
'top_a': neutral_samplers['top_a'],
'top_n_sigma': neutral_samplers['top_n_sigma'],
'adaptive_target': neutral_samplers['adaptive_target'],
'adaptive_decay': neutral_samplers['adaptive_decay'],
@ -347,7 +342,7 @@ settings = {
'greeting': 'How can I help you today?',
'custom_system_message': '',
'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}",
'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}",
'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {{- '[Tool result: ' + message['content'] + ']\\n' -}}\n {%- elif message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- elif message['tool_calls'] is defined and message['tool_calls'] -%}\n {%- for tc in message['tool_calls'] -%}\n {{- '[Calling: ' + tc['function']['name'] + '(' + tc['function']['arguments'] + ')]\\n' -}}\n {%- endfor -%}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}",
# Extensions
'default_extensions': [],
@ -395,9 +390,16 @@ def do_cmd_flags_warnings():
if args.share:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
if args.multi_user:
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
logger.warning("You are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
if args.multi_user:
logger.warning(
'Multi-user mode is enabled. Known limitations:'
'\n- The Stop button stops generation for all users, not just you.'
'\n- Chat history is not saved and will be lost on page refresh.'
'\n- Only one user can generate at a time unless using a parallel-capable backend (e.g. llama.cpp with --parallel N for N > 1, or ExLlamaV3).'
'\n\nThis mode works best for small trusted teams.'
'\n\nDo not expose publicly. Grayed-out actions can easily be bypassed client-side.\n'
)
def apply_image_model_cli_overrides():


@ -78,10 +78,13 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
reply = ''
is_stream = state['stream']
if len(all_stop_strings) > 0 and not state['stream']:
    original_logits_processor = state.get('logits_processor')
    stop_event_ref = state.pop('stop_event', None)
    state = copy.deepcopy(state)
    if stop_event_ref is not None:
        state['stop_event'] = stop_event_ref
    if original_logits_processor is not None:
        state['logits_processor'] = original_logits_processor
    state['stream'] = True
# Generate
@ -375,7 +378,7 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
if state['custom_token_bans']:
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
to_ban = [int(x.strip()) for x in state['custom_token_bans'].split(',') if x.strip()]
if len(to_ban) > 0:
if generate_params.get('suppress_tokens', None):
generate_params['suppress_tokens'] += to_ban
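
The change to the `custom_token_bans` parsing matters for input with trailing or doubled commas: `int('')` raises `ValueError`, so the earlier comprehension fails on a string such as `'13,'`. A minimal sketch of the filtered form follows; `parse_token_bans` is a hypothetical wrapper name used only for this illustration.

```python
def parse_token_bans(raw):
    """Parse a comma-separated string of token IDs, skipping blank entries."""
    return [int(x.strip()) for x in raw.split(',') if x.strip()]


bans = parse_token_bans('13, 198,, 50256,')
print(bans)
```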

667
modules/tool_parsing.py Normal file

@ -0,0 +1,667 @@
import json
import random
import re
def get_tool_call_id() -> str:
    letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
    b = [random.choice(letter_bytes) for _ in range(8)]
    return "call_" + "".join(b).lower()
# All known opening markers for tool calls across model formats.
TOOL_CALL_OPENING_MARKERS = [
'<tool_call>',
'<function_call>',
'<minimax:tool_call>',
'<|tool_call_begin|>',
'<|tool_calls_section_begin|>',
'<tool▁call▁begin>',
'<tool▁calls▁begin>',
'[TOOL_CALLS]',
'to=functions.',
'<|channel|>commentary',
]
def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False):
    '''
    Check whether streaming output should be withheld because it may
    contain tool-call markup.

    Args:
        text: Full accumulated internal text.
        markers: Template-specific markers for partial-prefix matching.
                 If None, falls back to TOOL_CALL_OPENING_MARKERS.
        tool_names: List of tool function names.
        check_bare_names: Whether to do partial-prefix matching on tool
                          names (for models with unknown template format).
    '''
    # Full marker found in text → buffer permanently.
    # Always checks ALL known markers regardless of template (cheap safety net).
    for marker in TOOL_CALL_OPENING_MARKERS:
        if marker in text:
            return True

    # Bare function-name full match: "get_weather{...}" or "get_weather {...}"
    if tool_names:
        for name in tool_names:
            if name + '{' in text or name + ' {' in text:
                return True

    # Partial-prefix matching: only for template-specific markers.
    for marker in (markers if markers is not None else TOOL_CALL_OPENING_MARKERS):
        for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1):
            if text.endswith(marker[:prefix_len]):
                return True

    # Bare-name partial matching: only when template format is unknown.
    if check_bare_names and tool_names:
        for name in tool_names:
            if text.endswith(name):
                return True
            for prefix_len in range(min(len(name) - 1, len(text)), 0, -1):
                if text.endswith(name[:prefix_len]):
                    return True

    return False
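
The partial-prefix loop is the least obvious part of the check above, so here it is isolated as a small sketch. `ends_with_partial_marker` is a hypothetical name chosen for this example; the loop body is the same as in the function above, applied to one marker.

```python
def ends_with_partial_marker(text, marker):
    """Return True if text ends with a proper, non-empty prefix of marker."""
    # Try the longest candidate prefix first, down to a single character.
    for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1):
        if text.endswith(marker[:prefix_len]):
            return True
    return False


print(ends_with_partial_marker('Sure, let me check. <tool', '<tool_call>'))
print(ends_with_partial_marker('Hello there', '<tool_call>'))
```

Because only proper prefixes are tried, a text that ends with the complete marker is not matched here; the full-marker case is covered by the separate `marker in text` test. Note also that a lone trailing `<` is enough to withhold output until the next token disambiguates it.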
def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]):
    # check if property 'function' exists and is a dictionary, otherwise adapt dict
    if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
        candidate_dict = {"type": "function", "function": candidate_dict}
    if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
        candidate_dict['name'] = candidate_dict['function']
        del candidate_dict['function']
        candidate_dict = {"type": "function", "function": candidate_dict}
    if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
        # check if 'name' exists within 'function' and is part of known tools
        if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
            candidate_dict["type"] = "function"  # ensure required property 'type' exists and has the right value
            # map property 'parameters' used by some older models to 'arguments'
            if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
                candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
                del candidate_dict["function"]["parameters"]
            return candidate_dict
    return None
def _extract_balanced_json(text: str, start: int) -> str | None:
    """Extract a balanced JSON object from text starting at the given position.

    Walks through the string tracking brace depth and string boundaries
    to correctly handle arbitrary nesting levels.
    """
    if start >= len(text) or text[start] != '{':
        return None
    depth = 0
    in_string = False
    escape_next = False
    for i in range(start, len(text)):
        c = text[i]
        if escape_next:
            escape_next = False
            continue
        if c == '\\' and in_string:
            escape_next = True
            continue
        if c == '"':
            in_string = not in_string
            continue
        if in_string:
            continue
        if c == '{':
            depth += 1
        elif c == '}':
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None
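
For comparison, the standard library offers a different way to pull one JSON value out of a longer string: `json.JSONDecoder.raw_decode`, which decodes from a given index and reports where the value ended. This is an alternative technique, not what the module above does; `extract_json_object` and the sample string are hypothetical. Unlike the brace walker, it rejects anything that is not valid JSON instead of returning a balanced-but-invalid span.

```python
import json


def extract_json_object(text, start):
    """Decode the JSON value starting at text[start]; return (obj, end) or None."""
    try:
        return json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None


# Hypothetical model output with a closing brace inside a string value.
answer = '[TOOL_CALLS]get_weather[ARGS]{"city": "Paris", "note": "a } in a string"} trailing text'
start = answer.index('{')
obj, end = extract_json_object(answer, start)
print(obj, repr(answer[end:]))
```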
def _parse_channel_tool_calls(answer: str, tool_names: list[str]):
    """Parse channel-based tool calls used by GPT-OSS and similar models.

    Format:
        <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"}
    or:
        <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
    """
    matches = []
    start_pos = None
    # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
    # Pattern 2: to=functions.NAME after <|channel|> (alternative format)
    patterns = [
        r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
        r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
    ]
    for pattern in patterns:
        for m in re.finditer(pattern, answer):
            func_name = m.group(1).strip()
            if func_name not in tool_names:
                continue
            json_str = _extract_balanced_json(answer, m.end())
            if json_str is None:
                continue
            try:
                arguments = json.loads(json_str)
                if start_pos is None:
                    prefix = answer.rfind('<|start|>assistant', 0, m.start())
                    start_pos = prefix if prefix != -1 else m.start()
                matches.append({
                    "type": "function",
                    "function": {
                        "name": func_name,
                        "arguments": arguments
                    }
                })
            except json.JSONDecodeError:
                pass
        if matches:
            break
    return matches, start_pos
def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]):
    """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.

    Format:
        [TOOL_CALLS]func_name[ARGS]{"arg": "value"}
    """
    matches = []
    start_pos = None
    for m in re.finditer(
        r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        json_str = _extract_balanced_json(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                start_pos = m.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            pass
    return matches, start_pos
def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]):
    """Parse bare function-name style tool calls used by Mistral and similar models.

    Format:
        functionName{"arg": "value"}
    Multiple calls are concatenated directly or separated by whitespace.
    """
    matches = []
    start_pos = None
    # Match tool name followed by opening brace, then extract balanced JSON
    escaped_names = [re.escape(name) for name in tool_names]
    pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
    for match in re.finditer(pattern, answer):
        text = match.group(0)
        name = None
        for n in tool_names:
            if text.startswith(n):
                name = n
                break
        if not name:
            continue
        brace_start = match.end() - 1
        json_str = _extract_balanced_json(answer, brace_start)
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                start_pos = match.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            pass
    return matches, start_pos
def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]):
    """Parse XML-parameter style tool calls used by Qwen3.5 and similar models.

    Format:
        <tool_call>
        <function=function_name>
        <parameter=param_name>value</parameter>
        </function>
        </tool_call>
    """
    matches = []
    start_pos = None
    for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
        tc_content = tc_match.group(1)
        func_match = re.search(r'<function=([^>]+)>', tc_content)
        if not func_match:
            continue
        func_name = func_match.group(1).strip()
        if func_name not in tool_names:
            continue
        arguments = {}
        for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
            param_name = param_match.group(1).strip()
            param_value = param_match.group(2).strip()
            try:
                param_value = json.loads(param_value)
            except (json.JSONDecodeError, ValueError):
                pass  # keep as string
            arguments[param_name] = param_value
        if start_pos is None:
            start_pos = tc_match.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })
    return matches, start_pos
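
The XML-style parsers receive every parameter as raw text and rely on a "try JSON, otherwise keep the string" step to recover types. A minimal sketch of that coercion on its own follows; `coerce_param` is a hypothetical helper name for this illustration.

```python
import json


def coerce_param(value):
    """Return the JSON-decoded value if it parses, else the original string."""
    try:
        return json.loads(value)
    except (json.JSONDecodeError, ValueError):
        return value  # keep as string


for raw in ['42', 'true', 'Paris', '[1, 2]']:
    print(repr(raw), '->', repr(coerce_param(raw)))
```

One consequence worth keeping in mind: a string argument that happens to look like JSON (for example a ZIP code written as `90210`) comes back as a number, not a string.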
def _parse_kimi_tool_calls(answer: str, tool_names: list[str]):
    """Parse Kimi-K2-style tool calls using pipe-delimited tokens.

    Format:
        <|tool_calls_section_begin|>
        <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
        <|tool_calls_section_end|>
    """
    matches = []
    start_pos = None
    for m in re.finditer(
        r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        json_str = _extract_balanced_json(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                # Check for section begin marker before the call marker
                section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start())
                start_pos = section if section != -1 else m.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            pass
    return matches, start_pos
def _parse_minimax_tool_calls(answer: str, tool_names: list[str]):
    """Parse MiniMax-style tool calls using invoke/parameter XML tags.

    Format:
        <minimax:tool_call>
        <invoke name="function_name">
        <parameter name="param_name">value</parameter>
        </invoke>
        </minimax:tool_call>
    """
    matches = []
    start_pos = None
    for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
        tc_content = tc_match.group(1)
        # Split on <invoke> to handle multiple parallel calls in one block
        for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
            func_name = invoke_match.group(1).strip()
            if func_name not in tool_names:
                continue
            invoke_body = invoke_match.group(2)
            arguments = {}
            for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
                param_name = param_match.group(1).strip()
                param_value = param_match.group(2).strip()
                try:
                    param_value = json.loads(param_value)
                except (json.JSONDecodeError, ValueError):
                    pass  # keep as string
                arguments[param_name] = param_value
            if start_pos is None:
                start_pos = tc_match.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
    return matches, start_pos
def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]):
    """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.

    Format:
        <tool▁calls▁begin><tool▁call▁begin>func_name<tool▁sep>{"arg": "value"}<tool▁call▁end><tool▁calls▁end>
    """
    matches = []
    start_pos = None
    for m in re.finditer(
        r'<tool▁call▁begin>\s*(\S+?)\s*<tool▁sep>\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        json_str = _extract_balanced_json(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                # Check for section begin marker before the call marker
                section = answer.rfind('<tool▁calls▁begin>', 0, m.start())
                start_pos = section if section != -1 else m.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            pass
    return matches, start_pos
def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
    """Parse GLM-style tool calls using arg_key/arg_value XML pairs.

    Format:
        <tool_call>function_name
        <arg_key>key1</arg_key>
        <arg_value>value1</arg_value>
        </tool_call>
    """
    matches = []
    start_pos = None
    for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
        tc_content = tc_match.group(1)
        # First non-tag text is the function name
        name_match = re.match(r'([^<\s]+)', tc_content.strip())
        if not name_match:
            continue
        func_name = name_match.group(1).strip()
        if func_name not in tool_names:
            continue
        # Extract arg_key/arg_value pairs
        keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
        vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
        if len(keys) != len(vals):
            continue
        arguments = {}
        for k, v in zip(keys, vals):
            try:
                v = json.loads(v)
            except (json.JSONDecodeError, ValueError):
                pass  # keep as string
            arguments[k] = v
        if start_pos is None:
            start_pos = tc_match.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })
    return matches, start_pos
def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
    """Parse pythonic-style tool calls used by Llama 4 and similar models.

    Format:
        [func_name(param1="value1", param2="value2"), func_name2(...)]
    """
    matches = []
    start_pos = None
    # Match a bracketed list of function calls
    bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
    if not bracket_match:
        return matches, start_pos
    inner = bracket_match.group(1)
    # Build pattern for known tool names
    escaped_names = [re.escape(name) for name in tool_names]
    name_pattern = '|'.join(escaped_names)
    for call_match in re.finditer(
        r'(' + name_pattern + r')\(([^)]*)\)',
        inner
    ):
        func_name = call_match.group(1)
        params_str = call_match.group(2).strip()
        arguments = {}
        if params_str:
            # Parse key="value" pairs, handling commas inside quoted values
            for param_match in re.finditer(
                r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
                params_str
            ):
                param_name = param_match.group(1)
                param_value = param_match.group(2).strip()
                # Strip surrounding quotes
                if (param_value.startswith('"') and param_value.endswith('"')) or \
                        (param_value.startswith("'") and param_value.endswith("'")):
                    param_value = param_value[1:-1]
                # Try to parse as JSON for numeric/bool/null values
                try:
                    param_value = json.loads(param_value)
                except (json.JSONDecodeError, ValueError):
                    pass
                arguments[param_name] = param_value
        if start_pos is None:
            start_pos = bracket_match.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })
    return matches, start_pos
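
The keyword-argument regex in the pythonic parser can be exercised on its own. The sketch below reuses that pattern verbatim; `PARAM_RE`, `parse_kwargs`, and the sample argument string are hypothetical names and data for this illustration, and the quote-stripping condition is written slightly more compactly than in the parser.

```python
import json
import re

# Same pattern as in the parser: a name, '=', then a double-quoted string,
# a single-quoted string, or a bare token running up to the next ',' or ')'.
PARAM_RE = re.compile(r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)')


def parse_kwargs(params_str):
    """Turn 'a="x", b=3' into a dict, decoding JSON-like values where possible."""
    arguments = {}
    for m in PARAM_RE.finditer(params_str):
        value = m.group(2).strip()
        # Strip one pair of matching surrounding quotes.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in '"\'':
            value = value[1:-1]
        try:
            value = json.loads(value)
        except (json.JSONDecodeError, ValueError):
            pass  # keep as string
        arguments[m.group(1)] = value
    return arguments


args = parse_kwargs('city="Paris, France", days=3, metric=true')
print(args)
```

Two limits follow directly from the patterns in the parser: the call pattern `\(([^)]*)\)` stops at the first `)`, so an argument value containing a closing parenthesis truncates the call, and a Python-style literal such as `True` is not valid JSON, so it stays the string `'True'`.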
# Format registry: maps template substrings to the parser and streaming
# markers for that format. When a format's hints are NOT found in the
# template, its parser and markers are excluded.
TOOL_CALL_FORMATS = [
    {
        'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'],
        'parser': _parse_deep_seek_tool_calls,
        'markers': ['<tool▁call▁begin>', '<tool▁calls▁begin>'],
    },
    {
        'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'],
        'parser': _parse_kimi_tool_calls,
        'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'],
    },
    {
        'template_hints': ['to=functions.', '<|channel|>'],
        'parser': _parse_channel_tool_calls,
        'markers': ['to=functions.', '<|channel|>commentary'],
    },
    {
        'template_hints': ['minimax:tool_call'],
        'parser': _parse_minimax_tool_calls,
        'markers': ['<minimax:tool_call>'],
    },
    {
        'template_hints': ['<arg_key>'],
        'parser': _parse_glm_tool_calls,
        'markers': ['<tool_call>'],
    },
    {
        'template_hints': ['<tool_call>'],
        'parser': _parse_xml_param_tool_calls,
        'markers': ['<tool_call>'],
    },
    {
        'template_hints': ['[TOOL_CALLS]'],
        'parser': _parse_mistral_token_tool_calls,
        'markers': ['[TOOL_CALLS]'],
    },
    {
        'template_hints': ['<function_call>'],
        'parser': None,
        'markers': ['<function_call>'],
    },
]

# Default ordered list of all specialized parsers.
ALL_PARSERS = [
    _parse_deep_seek_tool_calls,
    _parse_kimi_tool_calls,
    _parse_channel_tool_calls,
    _parse_minimax_tool_calls,
    _parse_glm_tool_calls,
    _parse_xml_param_tool_calls,
    _parse_mistral_token_tool_calls,
    _parse_bare_name_tool_calls,
    _parse_pythonic_tool_calls,
]
def detect_tool_call_format(template_str):
    """Inspect a chat/instruction template to determine which tool call
    formats are relevant.

    Uses an exclude-based approach: starts with all parsers/markers,
    then removes the ones whose hints are not found in the template.

    Returns (parsers, streaming_markers, check_bare_names).
    """
    if not template_str:
        return None, TOOL_CALL_OPENING_MARKERS, True

    matched_any = False
    exclude_parsers = []
    exclude_markers = []
    matched_markers = []

    for fmt in TOOL_CALL_FORMATS:
        if any(hint in template_str for hint in fmt['template_hints']):
            matched_any = True
            matched_markers.extend(fmt['markers'])
        else:
            if fmt['parser'] is not None:
                exclude_parsers.append(fmt['parser'])
            exclude_markers.extend(fmt['markers'])

    if not matched_any:
        return None, TOOL_CALL_OPENING_MARKERS, True

    parsers = [p for p in ALL_PARSERS if p not in exclude_parsers]
    markers = [m for m in TOOL_CALL_OPENING_MARKERS if m not in exclude_markers or m in matched_markers]
    return parsers, markers, False
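The exclude-based selection is easiest to follow on a reduced registry. The sketch below is an illustration under stated assumptions: the `parse_*` functions are empty stand-ins, `DEMO_FORMATS`, `DEMO_PARSERS` and `DEMO_MARKERS` are hypothetical three-entry versions of the real tables, and `detect` mirrors the logic of `detect_tool_call_format` without the `None` parser case. It shows why the `or m in matched_markers` clause matters when two formats share the `<tool_call>` marker.

```python
# Stand-in parsers (hypothetical): each returns (matches, start_pos)
def parse_minimax(answer, tool_names):
    return [], None


def parse_glm(answer, tool_names):
    return [], None


def parse_xml(answer, tool_names):
    return [], None


def parse_bare(answer, tool_names):
    return [], None


DEMO_FORMATS = [
    {'template_hints': ['minimax:tool_call'], 'parser': parse_minimax, 'markers': ['<minimax:tool_call>']},
    {'template_hints': ['<arg_key>'], 'parser': parse_glm, 'markers': ['<tool_call>']},
    {'template_hints': ['<tool_call>'], 'parser': parse_xml, 'markers': ['<tool_call>']},
]
DEMO_PARSERS = [parse_minimax, parse_glm, parse_xml, parse_bare]
DEMO_MARKERS = ['<minimax:tool_call>', '<tool_call>']


def detect(template_str):
    """Reduced copy of the exclude-based selection."""
    if not template_str:
        return None, DEMO_MARKERS, True
    matched_any = False
    exclude_parsers, exclude_markers, matched_markers = [], [], []
    for fmt in DEMO_FORMATS:
        if any(hint in template_str for hint in fmt['template_hints']):
            matched_any = True
            matched_markers.extend(fmt['markers'])
        else:
            exclude_parsers.append(fmt['parser'])
            exclude_markers.extend(fmt['markers'])
    if not matched_any:
        return None, DEMO_MARKERS, True
    parsers = [p for p in DEMO_PARSERS if p not in exclude_parsers]
    # A marker shared by an excluded and a matched format must survive
    markers = [m for m in DEMO_MARKERS if m not in exclude_markers or m in matched_markers]
    return parsers, markers, False


parsers, markers, check_bare_names = detect("... wrap calls in <tool_call></tool_call> ...")
print([p.__name__ for p in parsers], markers, check_bare_names)
# ['parse_xml', 'parse_bare'] ['<tool_call>'] False
```

The GLM entry is excluded because `<arg_key>` is absent, yet its `<tool_call>` marker is kept because the XML entry matched on the same marker. Parsers that have no registry entry, such as the bare-name stand-in here, are never excluded.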
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None):
    matches = []
    start_pos = None

    def _return(matches, start_pos):
        if return_prefix:
            prefix = answer[:start_pos] if matches and start_pos is not None else ''
            return matches, prefix
        return matches

    # Try specialized parsers.
    for parser in (parsers if parsers is not None else ALL_PARSERS):
        matches, start_pos = parser(answer, tool_names)
        if matches:
            return _return(matches, start_pos)

    # Generic fallback: regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]

    for pattern in patterns:
        for match in re.finditer(pattern, answer, re.DOTALL):
            if match.group(2) is None:
                continue
            # remove backtick wraps if present
            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
            candidate = re.sub(r"```$", "", candidate.strip())
            # unwrap inner tags
            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
                if not candidate.strip().startswith("["):
                    candidate = "[" + candidate + "]"

            candidates = []
            try:
                # parse the candidate JSON into a dictionary
                candidates = json.loads(candidate)
                if not isinstance(candidates, list):
                    candidates = [candidates]
            except json.JSONDecodeError:
                # Ignore invalid JSON silently
                continue

            for candidate_dict in candidates:
                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    if start_pos is None:
                        start_pos = match.start()
                    matches.append(checked_candidate)

    # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
    if len(matches) == 0:
        try:
            candidate = answer
            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
                if not candidate.strip().startswith("["):
                    candidate = "[" + candidate + "]"
            # parse the candidate JSON into a dictionary
            candidates = json.loads(candidate)
            if not isinstance(candidates, list):
                candidates = [candidates]
            for candidate_dict in candidates:
                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    matches.append(checked_candidate)
        except json.JSONDecodeError:
            # Ignore invalid JSON silently
            pass

    return _return(matches, start_pos)
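The generic fallback repeats one small step twice: newline-separated JSON objects are joined into a JSON array before parsing. A minimal sketch of that step alone follows; `load_json_objects` is a hypothetical name, and the nesting of the `startswith("[")` check inside the first `if` is an assumption, since indentation is not visible in this view.

```python
import json
import re


def load_json_objects(candidate):
    """Hypothetical helper: parse one JSON value, or several objects
    separated by line breaks, and always return a list."""
    if re.search(r"\}\s*\n\s*\{", candidate) is not None:
        # Turn "}\n{" into "},\n{" so the objects become array elements
        candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
        if not candidate.strip().startswith("["):
            candidate = "[" + candidate + "]"
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return []  # invalid JSON is ignored silently, as in the hunk
    return parsed if isinstance(parsed, list) else [parsed]


two_calls = '{"name": "a", "arguments": {}}\n{"name": "b", "arguments": {}}'
print([call["name"] for call in load_json_objects(two_calls)])
# ['a', 'b']
```

Each returned dictionary would still need the validation that `check_and_sanitize_tool_call_candidate` performs in the real function.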

modules/tool_use.py (new file, 71 lines)

@ -0,0 +1,71 @@
import importlib.util
import json
from modules import shared
from modules.logging_colors import logger
from modules.utils import natural_keys, sanitize_filename
def get_available_tools():
    """Return sorted list of tool script names from user_data/tools/*.py."""
    tools_dir = shared.user_data_dir / 'tools'
    tools_dir.mkdir(parents=True, exist_ok=True)
    return sorted((p.stem for p in tools_dir.glob('*.py')), key=natural_keys)


def load_tools(selected_names):
    """
    Import selected tool scripts and return their definitions and executors.
    Returns (tool_defs, executors) where:
      - tool_defs: list of OpenAI-format tool dicts
      - executors: dict mapping function_name -> execute callable
    """
    tool_defs = []
    executors = {}
    for name in selected_names:
        name = sanitize_filename(name)
        if not name:
            continue

        path = shared.user_data_dir / 'tools' / f'{name}.py'
        if not path.exists():
            continue

        try:
            spec = importlib.util.spec_from_file_location(f"tool_{name}", str(path))
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        except Exception:
            logger.exception(f'Failed to load tool script "{name}"')
            continue

        tool_def = getattr(module, 'tool', None)
        execute_fn = getattr(module, 'execute', None)
        if tool_def is None or execute_fn is None:
            logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.')
            continue

        func_name = tool_def.get('function', {}).get('name', name)
        if func_name in executors:
            logger.warning(f'Tool "{name}" declares function name "{func_name}" which conflicts with an already loaded tool. Skipping.')
            continue

        tool_defs.append(tool_def)
        executors[func_name] = execute_fn

    return tool_defs, executors
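`load_tools` expects each script to expose two module-level names: a `tool` dict and an `execute` callable that takes the parsed arguments. A minimal sketch of such a script is shown below, assuming it were saved as `user_data/tools/add_numbers.py`; the file name, the function name `add_numbers` and its parameters are hypothetical and chosen only for illustration.

```python
# Hypothetical tool script: user_data/tools/add_numbers.py

# OpenAI-format tool definition, looked up by load_tools via getattr(module, 'tool')
tool = {
    "type": "function",
    "function": {
        "name": "add_numbers",
        "description": "Add two numbers and return their sum.",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {"type": "number", "description": "First addend"},
                "b": {"type": "number", "description": "Second addend"},
            },
            "required": ["a", "b"],
        },
    },
}


def execute(arguments):
    """Receives the arguments as a dict; the return value may be a
    JSON-serializable object or a string."""
    return {"sum": arguments["a"] + arguments["b"]}
```

With this layout the executor would be registered under `tool["function"]["name"]`, falling back to the file stem when no name is declared.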
def execute_tool(func_name, arguments, executors):
    """Execute a tool by function name. Returns result as a JSON string."""
    fn = executors.get(func_name)
    if fn is None:
        return json.dumps({"error": f"Unknown tool: {func_name}"})

    try:
        if isinstance(arguments, str):
            arguments = json.loads(arguments)
        result = fn(arguments)
        return json.dumps(result) if not isinstance(result, str) else result
    except Exception as e:
        logger.exception(f'Tool "{func_name}" execution failed')
        return json.dumps({"error": str(e)})
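The three result shapes of `execute_tool` (normal result, unknown tool, failing tool) can be seen with a logger-free copy. This is a sketch under that assumption: `run_tool` is a hypothetical name and the executors are throwaway lambdas.

```python
import json


def run_tool(func_name, arguments, executors):
    """Logger-free copy of the dispatch logic, for illustration only."""
    fn = executors.get(func_name)
    if fn is None:
        return json.dumps({"error": f"Unknown tool: {func_name}"})
    try:
        # Arguments may arrive as a JSON string or as an already parsed dict
        if isinstance(arguments, str):
            arguments = json.loads(arguments)
        result = fn(arguments)
        return result if isinstance(result, str) else json.dumps(result)
    except Exception as e:
        return json.dumps({"error": str(e)})


executors = {
    "echo": lambda args: args,
    "fail": lambda args: 1 / 0,
}

print(run_tool("echo", '{"x": 1}', executors))  # {"x": 1}
print(run_tool("missing", {}, executors))       # {"error": "Unknown tool: missing"}
print(run_tool("fail", {}, executors))          # {"error": "division by zero"}
```

In every case the caller receives a string, so a tool failure reaches the model as an error message rather than as an exception.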


@ -310,6 +310,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
# == Input validation / processing ==
yield "Preparing the input..."
if shared.args.loader == 'llama.cpp':
yield "Error: LoRA training requires a model loaded with the Transformers loader. GGUF models are not supported for training."
return
lora_file_path = clean_path(None, lora_name)
if lora_file_path.strip() == '':
yield "Missing or invalid LoRA file name input."


@ -65,14 +65,16 @@ class LogprobProcessor(LogitsProcessor):
def __init__(self, logprobs=None):
self.logprobs = logprobs
self.token_alternatives = {}
self.token_alternatives_history = []
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logprobs is not None: # 0-5
log_e_probabilities = F.log_softmax(logits, dim=1)
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs)
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
top_probs = [float(x) for x in top_values[0]]
self.token_alternatives = dict(zip(top_tokens, top_probs))
self.token_alternatives_history.append(self.token_alternatives)
return logits
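The processor records, for each generated token, the `logprobs` most likely alternatives with their natural-log probabilities. It does so with `F.log_softmax` and `torch.topk`; the sketch below swaps in the standard library so it runs without PyTorch. `top_logprobs` is a hypothetical name, and token indices stand in for decoded token strings.

```python
import math


def top_logprobs(logits, k):
    """Return the k highest (index, log-probability) pairs for one step."""
    # Numerically stable log-softmax: subtract the max before exponentiating
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    logprobs = [x - log_z for x in logits]
    order = sorted(range(len(logits)), key=lambda i: logprobs[i], reverse=True)
    return [(i, logprobs[i]) for i in order[:k]]


history = []  # plays the role of token_alternatives_history
for step_logits in ([2.0, 1.0, 0.0], [0.0, 3.0, 0.5]):
    history.append(dict(top_logprobs(step_logits, k=2)))

print([list(step) for step in history])
# [[0, 1], [1, 2]]
```

Appending one dictionary per step is what lets a caller report alternatives for every token of a completion rather than only the last one.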
@ -134,8 +136,6 @@ def load_model_HF(model_name):
shared.args.load_in_4bit,
shared.args.disk,
shared.args.cpu_memory is not None,
shared.args.compress_pos_emb > 1,
shared.args.alpha_value > 1,
])
# Load the model without any special settings
@ -198,11 +198,6 @@ def load_model_HF(model_name):
if shared.args.disk:
params['offload_folder'] = str(Path(shared.args.disk_cache_dir))
if shared.args.compress_pos_emb > 1:
params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
elif shared.args.alpha_value > 1:
params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}
logger.info("TRANSFORMERS_PARAMS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
print()


@ -120,58 +120,8 @@ else:
def list_model_elements():
elements = [
'filter_by_loader',
'loader',
'cpu_memory',
'gpu_layers',
'fit_target',
'cpu_moe',
'threads',
'threads_batch',
'batch_size',
'ubatch_size',
'ctx_size',
'cache_type',
'tensor_split',
'extra_flags',
'streaming_llm',
'gpu_split',
'alpha_value',
'rope_freq_base',
'compress_pos_emb',
'compute_dtype',
'quant_type',
'load_in_8bit',
'load_in_4bit',
'attn_implementation',
'cpu',
'disk',
'row_split',
'no_kv_offload',
'no_mmap',
'mlock',
'numa',
'parallel',
'use_double_quant',
'bf16',
'enable_tp',
'tp_backend',
'cfg_cache',
'no_use_fast',
'model_draft',
'draft_max',
'gpu_layers_draft',
'device_draft',
'ctx_size_draft',
'spec_type',
'spec_ngram_size_n',
'spec_ngram_size_m',
'spec_ngram_min_hits',
'mmproj',
]
return elements
from modules.loaders import list_model_elements
return list_model_elements()
def list_interface_input_elements():
@ -249,6 +199,7 @@ def list_interface_input_elements():
'unique_id',
'textbox',
'start_with',
'selected_tools',
'mode',
'chat_style',
'chat-instruct_command',
@ -353,12 +304,16 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma
if k in shared.settings and k not in exclude:
output[k] = state[k]
output['preset'] = preset
if preset:
output['preset'] = preset
output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook']
output['character'] = state['character_menu']
if 'user_menu' in state and state['user_menu']:
if state.get('character_menu'):
output['character'] = state['character_menu']
if state.get('user_menu'):
output['user'] = state['user_menu']
output['seed'] = int(output['seed'])
output['custom_stopping_strings'] = output.get('custom_stopping_strings') or ''
output['custom_token_bans'] = output.get('custom_token_bans') or ''
output['show_controls'] = show_controls
output['dark_theme'] = True if theme_state == 'dark' else False
output.pop('instruction_template_str')
@ -470,6 +425,7 @@ def setup_auto_save():
'user_bio',
'custom_system_message',
'chat_template_str',
'selected_tools',
# Parameters tab (ui_parameters.py) - Generation parameters
'preset_menu',
@ -520,7 +476,6 @@ def setup_auto_save():
'skip_special_tokens',
'stream',
'static_cache',
'truncation_length',
'seed',
'sampler_priority',
'custom_stopping_strings',


@ -28,7 +28,8 @@ def create_ui():
shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu)
shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu)
shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat')
shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'])
shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'], elem_id='new-chat-btn')
shared.gradio['Start incognito chat'] = gr.Button('Incognito chat', visible=False, elem_id='incognito-chat-btn')
shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True)
shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat')
@ -91,6 +92,21 @@ def create_ui():
gr.HTML("<div class='sidebar-vertical-separator'></div>")
from modules.tool_use import get_available_tools
shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=shared.settings.get('selected_tools', []), label='Tools', info='Functions the model can call during generation.', elem_id='tools-group')
shared.gradio['tools_refresh'] = gr.Button('Refresh list', elem_id='tools-refresh-btn', visible=False)
shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']])
def sync_web_tools(selected):
if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools():
selected.append('fetch_webpage')
return gr.update(value=selected)
shared.gradio['selected_tools'].change(fn=sync_web_tools, inputs=[shared.gradio['selected_tools']], outputs=[shared.gradio['selected_tools']], show_progress=False)
gr.HTML("<div class='sidebar-vertical-separator'></div>")
with gr.Row():
shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')
@ -275,6 +291,10 @@ def create_event_handlers():
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
shared.gradio['Start incognito chat'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.handle_start_incognito_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
shared.gradio['delete_chat-confirm'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)


@ -728,6 +728,8 @@ def generate_prompt_variation(state):
variation = variation.rsplit("</think>", 1)[1]
elif "<|start|>assistant<|channel|>final<|message|>" in variation:
variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
elif "<|channel|>final<|message|>" in variation:
variation = variation.rsplit("<|channel|>final<|message|>", 1)[1]
elif "</seed:think>" in variation:
variation = variation.rsplit("</seed:think>", 1)[1]


@ -42,11 +42,11 @@ def create_ui():
with gr.Row():
with gr.Column():
shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=-1, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Number of layers to offload to the GPU. -1 = auto.')
shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. llama.cpp: 0 = auto if gpu-layers is also -1. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=1048576, step=1024, value=shared.args.ctx_size, info='Context length. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices. Default: 1024.')
shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices.')
shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
with gr.Column():
@ -100,9 +100,6 @@ def create_ui():
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
@ -388,7 +385,11 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
def update_truncation_length(current_length, state):
if 'loader' in state:
if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
return state['ctx_size']
if state['ctx_size'] > 0:
return state['ctx_size']
# ctx_size == 0 means auto: use the actual value from the server
return shared.settings['truncation_length']
return current_length


@ -37,10 +37,10 @@ def create_ui():
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=shared.settings['dynamic_temperature'], label='dynamic_temperature')
gr.Markdown('## Curve cutoff')
shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=shared.settings['min_p'], step=0.01, label='min_p')
shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=shared.settings['top_n_sigma'], step=0.01, label='top_n_sigma')
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=shared.settings['top_p'], step=0.01, label='top_p')
shared.gradio['top_k'] = gr.Slider(0, 200, value=shared.settings['top_k'], step=1, label='top_k')
shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=shared.settings['min_p'], step=0.01, label='min_p')
shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=shared.settings['top_n_sigma'], step=0.01, label='top_n_sigma')
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=shared.settings['typical_p'], step=0.01, label='typical_p')
shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=shared.settings['xtc_threshold'], step=0.01, label='xtc_threshold', info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.')
shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=shared.settings['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.')
@ -73,7 +73,7 @@ def create_ui():
gr.Markdown('## Other options')
shared.gradio['do_sample'] = gr.Checkbox(value=shared.settings['do_sample'], label='do_sample')
shared.gradio['temperature_last'] = gr.Checkbox(value=shared.settings['temperature_last'], label='temperature_last', info='Moves temperature/dynamic temperature/quadratic sampling to the end of the sampler stack, ignoring their positions in "Sampler priority".')
shared.gradio['sampler_priority'] = gr.Textbox(value=shared.settings['sampler_priority'], lines=10, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
shared.gradio['sampler_priority'] = gr.DragDrop(value=shared.settings['sampler_priority'], label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=shared.settings['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
with gr.Column():


@ -1,11 +1,12 @@
import concurrent.futures
import html
import ipaddress
import random
import re
import urllib.request
import socket
from concurrent.futures import as_completed
from datetime import datetime
from urllib.parse import quote_plus
from urllib.parse import parse_qs, quote_plus, urljoin, urlparse
import requests
@ -13,34 +14,60 @@ from modules import shared
from modules.logging_colors import logger
def _validate_url(url):
    """Validate that a URL is safe to fetch (not targeting private/internal networks)."""
    parsed = urlparse(url)
    if parsed.scheme not in ('http', 'https'):
        raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")

    hostname = parsed.hostname
    if not hostname:
        raise ValueError("No hostname in URL")

    # Resolve hostname and check all returned addresses
    try:
        for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
            ip = ipaddress.ip_address(sockaddr[0])
            if not ip.is_global:
                raise ValueError(f"Access to non-public address {ip} is blocked")
    except socket.gaierror:
        raise ValueError(f"Could not resolve hostname: {hostname}")
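The two checks that do not need DNS can be tried offline. This is a small sketch with hypothetical helper names (`has_allowed_scheme`, `is_public_address`); it uses the same `urlparse` scheme test and the same `ipaddress` property, `is_global`, that `_validate_url` relies on.

```python
import ipaddress
from urllib.parse import urlparse


def has_allowed_scheme(url):
    """Only plain web URLs pass; file://, ftp:// and similar are rejected."""
    return urlparse(url).scheme in ('http', 'https')


def is_public_address(addr):
    """True only for globally routable addresses."""
    return ipaddress.ip_address(addr).is_global


for addr in ('8.8.8.8', '127.0.0.1', '10.0.0.1', '192.168.1.1', '169.254.169.254', '::1'):
    print(addr, is_public_address(addr))
# Only 8.8.8.8 prints True

print(has_allowed_scheme('file:///etc/passwd'))  # False
```

Loopback, private ranges and the link-local metadata address are all rejected, which is the point of checking every address a hostname resolves to.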
def get_current_timestamp():
"""Returns the current time in 24-hour format"""
return datetime.now().strftime('%b %d, %Y %H:%M')
def download_web_page(url, timeout=10):
def download_web_page(url, timeout=10, include_links=False):
"""
Download a web page and convert its HTML content to structured Markdown text.
Download a web page and extract its main content as Markdown text.
"""
import html2text
import trafilatura
try:
_validate_url(url)
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}
response = requests.get(url, headers=headers, timeout=timeout)
response.raise_for_status() # Raise an exception for bad status codes
max_redirects = 5
for _ in range(max_redirects):
response = requests.get(url, headers=headers, timeout=timeout, allow_redirects=False)
if response.is_redirect and 'Location' in response.headers:
url = urljoin(url, response.headers['Location'])
_validate_url(url)
else:
break
# Initialize the HTML to Markdown converter
h = html2text.HTML2Text()
h.body_width = 0
h.ignore_images = True
h.ignore_links = True
response.raise_for_status()
# Convert the HTML to Markdown
markdown_text = h.handle(response.text)
return markdown_text
result = trafilatura.extract(
response.text,
include_links=include_links,
output_format='markdown',
url=url
)
return result or ""
except requests.exceptions.RequestException as e:
logger.error(f"Error downloading {url}: {e}")
return ""
@ -49,8 +76,8 @@ def download_web_page(url, timeout=10):
return ""
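The new download path disables automatic redirects and re-validates every hop, so a public URL cannot bounce the request to an internal address. The sketch below isolates that loop. It is not the function's code: `resolve_final_url` is a hypothetical name, and `fetch` and `validate` are injected stand-ins for `requests.get(..., allow_redirects=False)` and `_validate_url`, which keeps the example runnable without network access.

```python
from urllib.parse import urljoin


def resolve_final_url(url, fetch, validate, max_redirects=5):
    """Follow redirects by hand, validating each target before it is fetched.

    fetch(url) returns the Location header value, or None when the
    response is not a redirect.
    """
    validate(url)
    for _ in range(max_redirects):
        location = fetch(url)
        if location is None:
            break
        # Location may be relative, so resolve it against the current URL
        url = urljoin(url, location)
        validate(url)
    return url


def demo_validate(url):
    # Stand-in check: the real code resolves the hostname and tests is_global
    if '127.0.0.1' in url:
        raise ValueError(f"blocked: {url}")


safe_hops = {'https://a.example/': '/next'}
print(resolve_final_url('https://a.example/', safe_hops.get, demo_validate))
# https://a.example/next

unsafe_hops = {'https://a.example/': 'http://127.0.0.1/admin'}
try:
    resolve_final_url('https://a.example/', unsafe_hops.get, demo_validate)
except ValueError as e:
    print(e)  # blocked: http://127.0.0.1/admin
```

Validating inside the loop, before the next request is made, is what separates this from checking only the URL the caller supplied.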
def perform_web_search(query, num_pages=3, max_workers=5, timeout=10):
"""Perform web search and return results with content"""
def perform_web_search(query, num_pages=3, max_workers=5, timeout=10, fetch_content=True):
"""Perform web search and return results, optionally with page content"""
try:
search_url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}"
@ -59,25 +86,41 @@ def perform_web_search(query, num_pages=3, max_workers=5, timeout=10):
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
]
response_text = ""
req = urllib.request.Request(search_url, headers={'User-Agent': random.choice(agents)})
with urllib.request.urlopen(req, timeout=timeout) as response:
response_text = response.read().decode('utf-8')
response = requests.get(search_url, headers={'User-Agent': random.choice(agents)}, timeout=timeout)
response.raise_for_status()
response_text = response.text
# Extract results with regex
titles = re.findall(r'<a[^>]*class="[^"]*result__a[^"]*"[^>]*>(.*?)</a>', response_text, re.DOTALL)
urls = re.findall(r'<a[^>]*class="[^"]*result__url[^"]*"[^>]*>(.*?)</a>', response_text, re.DOTALL)
# Extract results - title and URL come from the same <a class="result__a"> element
result_links = re.findall(r'<a[^>]*class="[^"]*result__a[^"]*"[^>]*>(.*?)</a>', response_text, re.DOTALL)
result_tags = re.findall(r'<a([^>]*class="[^"]*result__a[^"]*"[^>]*)>', response_text, re.DOTALL)
# Prepare download tasks
download_tasks = []
for i in range(min(len(titles), len(urls), num_pages)):
url = f"https://{urls[i].strip()}"
title = re.sub(r'<[^>]+>', '', titles[i]).strip()
title = html.unescape(title)
download_tasks.append((url, title, i))
for i, (tag_attrs, raw_title) in enumerate(zip(result_tags, result_links)):
if num_pages is not None and i >= num_pages:
break
# Extract href and resolve the actual URL from DuckDuckGo's redirect link
href_match = re.search(r'href="([^"]*)"', tag_attrs)
if not href_match:
continue
uddg = parse_qs(urlparse(html.unescape(href_match.group(1))).query).get('uddg', [''])[0]
if not uddg:
continue
title = html.unescape(re.sub(r'<[^>]+>', '', raw_title).strip())
download_tasks.append((uddg, title, len(download_tasks)))
search_results = [None] * len(download_tasks) # Pre-allocate to maintain order
if not fetch_content:
for url, title, index in download_tasks:
search_results[index] = {
'title': title,
'url': url,
'content': ''
}
return search_results
# Download pages in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all download tasks
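The search hunk now reads each result's target from the `uddg` query parameter of the link's `href` instead of rebuilding it from the displayed URL text. A minimal sketch of that extraction follows; `resolve_duckduckgo_href` is a hypothetical name, and the sample `href` is an assumed shape of a redirect link, not captured output.

```python
import html
from urllib.parse import parse_qs, urlparse


def resolve_duckduckgo_href(href):
    """Return the decoded target URL from a redirect href, or '' if absent."""
    # The href is HTML-escaped inside the page, so &amp; must become & first
    query = urlparse(html.unescape(href)).query
    return parse_qs(query).get('uddg', [''])[0]


sample_href = '//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&amp;rut=abc'
print(resolve_duckduckgo_href(sample_href))
# https://example.com/page
```

`parse_qs` percent-decodes the value, so no separate `unquote` call is needed, and links without a `uddg` parameter yield an empty string that the caller can skip.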


@ -91,7 +91,7 @@ def get_gpu_choice():
"What is your GPU?",
{
'A': 'NVIDIA',
'B': 'AMD - Linux/macOS only, requires ROCm 6.4',
'B': 'AMD - Linux only, ROCm 7.2',
'C': 'Apple M Series',
'D': 'Intel Arc (beta)',
'N': 'CPU mode'
@ -116,14 +116,12 @@ def get_pytorch_install_command(gpu_choice):
if gpu_choice == "NVIDIA_CUDA128":
return base_cmd + "--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback
elif gpu_choice == "AMD":
return base_cmd + "--index-url https://download.pytorch.org/whl/rocm6.4" + pypi_fallback
py_tag = f"cp{PYTHON_VERSION.replace('.', '')}"
return f"python -m pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl"
elif gpu_choice in ["APPLE", "NONE"]:
return base_cmd + "--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback
elif gpu_choice == "INTEL":
if is_linux():
return "python -m pip install torch==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
else:
return "python -m pip install torch==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
return base_cmd + "--index-url https://download.pytorch.org/whl/xpu"
else:
return base_cmd
@ -136,12 +134,12 @@ def get_pytorch_update_command(gpu_choice):
if gpu_choice == "NVIDIA_CUDA128":
return f"{base_cmd}--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback
elif gpu_choice == "AMD":
return f"{base_cmd}--index-url https://download.pytorch.org/whl/rocm6.4" + pypi_fallback
py_tag = f"cp{PYTHON_VERSION.replace('.', '')}"
return f"python -m pip install --upgrade https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl"
elif gpu_choice in ["APPLE", "NONE"]:
return f"{base_cmd}--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback
elif gpu_choice == "INTEL":
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
return f"{base_cmd}{intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
return f"{base_cmd}--index-url https://download.pytorch.org/whl/xpu"
else:
return base_cmd
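For AMD, both commands now point at a single wheel file whose name embeds the interpreter tag twice. The sketch below shows how that URL is assembled. `rocm_wheel_url` is a hypothetical helper, and the values passed for the torch and Python versions are placeholders, since `TORCH_VERSION` and `PYTHON_VERSION` are defined outside this hunk.

```python
def rocm_wheel_url(torch_version, python_version):
    """Build the direct wheel URL used for the AMD branch (Linux x86_64 only)."""
    # "3.13" -> "cp313", the CPython tag used in wheel file names
    py_tag = f"cp{python_version.replace('.', '')}"
    return (
        "https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/"
        f"torch-{torch_version}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl"
    )


# Placeholder versions, for illustration only
print(rocm_wheel_url("2.9.1", "3.13"))
```

Because the file name is fixed to `linux_x86_64`, this branch has no Windows or macOS counterpart, which matches the updated "AMD - Linux only" menu entry.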
@ -196,6 +194,8 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False,
if environment:
if is_windows():
conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat")
python_path = os.path.join(conda_env_path, "python.exe")
cmd = cmd.replace("python ", f'"{python_path}" ')
cmd = f'"{conda_bat_path}" activate "{conda_env_path}" >nul && {cmd}'
else:
conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh")
@ -270,7 +270,7 @@ def update_pytorch_and_python():
def clean_outdated_pytorch_cuda_dependencies():
patterns = ["cu121", "cu122", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"]
patterns = ["cu121", "cu122", "rocm6", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"]
result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True)
matching_packages = []
@ -316,13 +316,6 @@ def install_webui():
install_pytorch = get_pytorch_install_command(gpu_choice)
run_cmd(f"conda install -y ninja git && {install_pytorch}", assert_success=True, environment=True)
if gpu_choice == "INTEL":
# Install oneAPI dependencies via conda
print_big_message("Installing Intel oneAPI runtime libraries.")
run_cmd("conda install -y -c https://software.repos.intel.com/python/conda/ -c conda-forge dpcpp-cpp-rt=2024.0 mkl-dpcpp=2024.0", environment=True)
# Install libuv required by Intel-patched torch
run_cmd("conda install -y libuv", environment=True)
# Install the webui requirements
update_requirements(initial_installation=True, pull=False)
@ -365,8 +358,10 @@ def update_requirements(initial_installation=False, pull=True):
current_commit = get_current_commit()
wheels_changed = not os.path.exists(state_file)
installed_wheels = set()
if not wheels_changed:
state = load_state()
installed_wheels = set(state.get('installed_wheels', []))
if 'wheels_changed' in state or state.get('last_installed_commit') != current_commit:
wheels_changed = True
@ -431,9 +426,17 @@ def update_requirements(initial_installation=False, pull=True):
# Prepare the requirements file
textgen_requirements = open(requirements_file).read().splitlines()
all_whl_lines = [line.strip() for line in textgen_requirements if '.whl' in line]
if not initial_installation and not wheels_changed:
textgen_requirements = [line for line in textgen_requirements if '.whl' not in line]
if not initial_installation:
if installed_wheels:
# Per-wheel comparison: only re-download wheels that changed
textgen_requirements = [
line for line in textgen_requirements
if '.whl' not in line or line.strip() not in installed_wheels
]
elif not wheels_changed:
textgen_requirements = [line for line in textgen_requirements if '.whl' not in line]
with open('temp_requirements.txt', 'w') as file:
file.write('\n'.join(textgen_requirements))
@ -452,6 +455,7 @@ def update_requirements(initial_installation=False, pull=True):
# Save state after successful installation
state = load_state()
state['last_installed_commit'] = current_commit
state['installed_wheels'] = all_whl_lines
state.pop('wheels_changed', None)
save_state(state)
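Saving `installed_wheels` in the state file is what enables the per-wheel comparison on the next update. The filtering rule can be sketched on its own as follows; `filter_requirements` is a hypothetical name, the wheel URLs are made up, and the branch order follows the hunk.

```python
def filter_requirements(lines, installed_wheels, wheels_changed, initial_installation=False):
    """Decide which requirement lines are passed to pip on this run."""
    if initial_installation:
        return list(lines)
    if installed_wheels:
        # Per-wheel comparison: skip only wheel lines that are already installed
        return [line for line in lines if '.whl' not in line or line.strip() not in installed_wheels]
    if not wheels_changed:
        # No per-wheel record available: fall back to skipping every wheel line
        return [line for line in lines if '.whl' not in line]
    return list(lines)


requirements = [
    'tqdm',
    'https://example.invalid/a-1.0-py3-none-any.whl',
    'https://example.invalid/b-2.0-py3-none-any.whl',
]
already_installed = {'https://example.invalid/a-1.0-py3-none-any.whl'}

print(filter_requirements(requirements, already_installed, wheels_changed=True))
# ['tqdm', 'https://example.invalid/b-2.0-py3-none-any.whl']
```

Since lines are compared as whole strings, a wheel whose URL or environment marker changed counts as new and is downloaded again, while unchanged wheels are left alone.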


@ -2,11 +2,10 @@ accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
bitsandbytes==0.49.*
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
flash-linear-attention==0.4.*
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -25,14 +24,15 @@ scipy
sentencepiece
tensorboard
torchao==0.15.*
trafilatura==2.0.0
transformers==5.3.*
triton-windows==3.5.1.post24; platform_system == "Windows"
tqdm
wandb
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -40,9 +40,9 @@ sse-starlette==1.6.5
tiktoken
# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.23/exllamav3-0.0.23+cu128.torch2.9.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.23/exllamav3-0.0.23+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.25/exllamav3-0.0.25+cu128.torch2.9.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.25/exllamav3-0.0.25+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"

View file

@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -37,5 +37,5 @@ sse-starlette==1.6.5
tiktoken
# AMD wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -37,4 +37,4 @@ sse-starlette==1.6.5
tiktoken
# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"

View file

@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -37,4 +37,4 @@ sse-starlette==1.6.5
tiktoken
# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"

View file

@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -37,5 +37,5 @@ sse-starlette==1.6.5
tiktoken
# llama.cpp (CPU only)
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"

View file

@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15

View file

@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# AMD wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -23,4 +23,4 @@ sse-starlette==1.6.5
tiktoken
# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"

View file

@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -23,4 +23,4 @@ sse-starlette==1.6.5
tiktoken
# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"

View file

@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# llama.cpp (CPU only)
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"

View file

@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu131-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu131-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu131-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu131-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15

View file

@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm
# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl
# API
flask_cloudflared==0.0.15
@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# Vulkan wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -1,58 +1,20 @@
import os
import shutil
import warnings
from pathlib import Path
from modules import shared, ui # ui must be imported early to avoid circular imports
from modules.image_models import load_image_model
from modules.logging_colors import logger
from modules.prompts import load_prompt
# Set up Gradio temp directory path
gradio_temp_path = shared.user_data_dir / 'cache' / 'gradio'
shutil.rmtree(gradio_temp_path, ignore_errors=True)
gradio_temp_path.mkdir(parents=True, exist_ok=True)
# Set environment variables
os.environ.update({
'GRADIO_ANALYTICS_ENABLED': 'False',
'BITSANDBYTES_NOWELCOME': '1',
'GRADIO_TEMP_DIR': str(gradio_temp_path)
})
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')
import gradio as gr
import os
import signal
import sys
import time
import warnings
from functools import partial
from pathlib import Path
from threading import Lock, Thread
import yaml
from modules import shared, utils
from modules.image_models import load_image_model
from modules.logging_colors import logger
from modules.prompts import load_prompt
import modules.extensions as extensions_module
from modules import (
training,
ui,
ui_chat,
ui_default,
ui_file_saving,
ui_image_generation,
ui_model_menu,
ui_notebook,
ui_parameters,
ui_session,
utils
)
from modules.chat import generate_pfp_cache
from modules.extensions import apply_extensions
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model_if_idle
from modules.models_settings import (
@ -61,10 +23,20 @@ from modules.models_settings import (
update_model_parameters
)
from modules.shared import do_cmd_flags_warnings
from modules.utils import gradio
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')
def signal_handler(sig, frame):
# On second Ctrl+C, force an immediate exit
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.signal(signal.SIGTERM, signal.SIG_DFL)
logger.info("Received Ctrl+C. Shutting down Text Generation Web UI gracefully.")
# Explicitly stop LlamaServer to avoid __del__ cleanup issues during shutdown
@ -83,6 +55,37 @@ signal.signal(signal.SIGTERM, signal_handler)
def create_interface():
import shutil
import gradio as gr
from modules import (
training,
ui,
ui_chat,
ui_default,
ui_file_saving,
ui_image_generation,
ui_model_menu,
ui_notebook,
ui_parameters,
ui_session,
)
from modules.chat import generate_pfp_cache
from modules.extensions import apply_extensions
from modules.utils import gradio
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
# Set up Gradio temp directory path
gradio_temp_path = shared.user_data_dir / 'cache' / 'gradio'
shutil.rmtree(gradio_temp_path, ignore_errors=True)
gradio_temp_path.mkdir(parents=True, exist_ok=True)
os.environ.update({
'GRADIO_ANALYTICS_ENABLED': 'False',
'GRADIO_TEMP_DIR': str(gradio_temp_path)
})
title = 'Text Generation Web UI'
# Password authentication
@ -215,6 +218,10 @@ def create_interface():
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False)
# Sync theme_state with the actual client-side theme so that
# autosave always writes the correct dark_theme value.
shared.gradio['interface'].load(None, None, gradio('theme_state'), js='() => document.body.classList.contains("dark") ? "dark" : "light"')
extensions_module.create_extensions_tabs() # Extensions tabs
extensions_module.create_extensions_block() # Extensions block

View file

@ -5,6 +5,7 @@ setlocal enabledelayedexpansion
set PYTHONNOUSERSITE=1
set PYTHONPATH=
set PYTHONHOME=
set PYTHONUTF8=1
cd /D "%~dp0"

View file

@ -1 +0,0 @@
min_p: 0.2

View file

@ -1,3 +0,0 @@
temperature: 0.7
top_p: 0.8
top_k: 20

View file

@ -1,3 +0,0 @@
temperature: 0.6
top_p: 0.95
top_k: 20

View file

@ -0,0 +1 @@
top_p: 0.95

View file

@ -1 +0,0 @@
min_p: 0.05

View file

@ -0,0 +1,52 @@
import ast
import operator
OPERATORS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Pow: operator.pow,
ast.Mod: operator.mod,
ast.USub: operator.neg,
}
def _eval(node):
if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
return node.value
elif isinstance(node, ast.BinOp) and type(node.op) in OPERATORS:
left = _eval(node.left)
right = _eval(node.right)
if isinstance(node.op, ast.Pow) and isinstance(right, (int, float)) and abs(right) > 10000:
raise ValueError("Exponent too large (max 10000)")
return OPERATORS[type(node.op)](left, right)
elif isinstance(node, ast.UnaryOp) and type(node.op) in OPERATORS:
return OPERATORS[type(node.op)](_eval(node.operand))
raise ValueError("Unsupported expression")
tool = {
"type": "function",
"function": {
"name": "calculate",
"description": "Evaluate a math expression. Supports +, -, *, /, **, %.",
"parameters": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "The math expression to evaluate (e.g. '2 * (3 + 4)')."},
},
"required": ["expression"]
}
}
}
def execute(arguments):
expr = arguments.get("expression", "")
try:
tree = ast.parse(expr, mode='eval')
result = _eval(tree.body)
return {"expression": expr, "result": result}
except Exception as e:
return {"error": str(e)}
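A condensed, standalone sketch of the same AST-whitelist technique shows how the tool is expected to behave on valid and rejected input. The names `ALLOWED`, `safe_eval`, and `calculate` are chosen for this sketch only and are not part of the added file.

```python
import ast
import operator

# Whitelist of AST operator node types mapped to their implementations.
ALLOWED = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.Mod: operator.mod,
    ast.USub: operator.neg,
}


def safe_eval(node):
    # Only numeric literals, whitelisted binary operators, and unary minus are accepted.
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in ALLOWED:
        left, right = safe_eval(node.left), safe_eval(node.right)
        if isinstance(node.op, ast.Pow) and abs(right) > 10000:
            raise ValueError("Exponent too large (max 10000)")
        return ALLOWED[type(node.op)](left, right)
    if isinstance(node, ast.UnaryOp) and type(node.op) in ALLOWED:
        return ALLOWED[type(node.op)](safe_eval(node.operand))
    raise ValueError("Unsupported expression")


def calculate(expression):
    # Parse in 'eval' mode so only a single expression is accepted.
    try:
        tree = ast.parse(expression, mode="eval")
        return {"expression": expression, "result": safe_eval(tree.body)}
    except Exception as e:
        return {"error": str(e)}


print(calculate("2 * (3 + 4)"))       # result is 14
print(calculate("2 ** 100000"))       # rejected by the exponent guard
print(calculate("__import__('os')"))  # rejected: function calls are not whitelisted
```

Anything that is not a numeric literal or a whitelisted operator, including names, calls, and attribute access, falls through to the final `raise`, so the expression string is never passed to `eval`.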

View file

@ -0,0 +1,30 @@
from modules.web_search import download_web_page, truncate_content_by_tokens
tool = {
"type": "function",
"function": {
"name": "fetch_webpage",
"description": "Fetch and read the contents of a web page given its URL. Returns the page content as plain text.",
"parameters": {
"type": "object",
"properties": {
"url": {"type": "string", "description": "The URL of the web page to fetch."},
"max_tokens": {"type": "integer", "description": "Maximum number of tokens in the returned content (default: 2048)."},
},
"required": ["url"]
}
}
}
def execute(arguments):
url = arguments.get("url", "")
max_tokens = arguments.get("max_tokens", 2048)
if not url:
return {"error": "No URL provided."}
content = download_web_page(url, include_links=True)
if not content or not content.strip():
return {"error": f"Failed to fetch content from {url}"}
return {"url": url, "content": truncate_content_by_tokens(content, max_tokens=max_tokens)}

View file

@ -0,0 +1,18 @@
from datetime import datetime
tool = {
"type": "function",
"function": {
"name": "get_datetime",
"description": "Get the current date and time.",
"parameters": {
"type": "object",
"properties": {},
}
}
}
def execute(arguments):
now = datetime.now()
return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")}

View file

@ -0,0 +1,23 @@
import random
tool = {
"type": "function",
"function": {
"name": "roll_dice",
"description": "Roll one or more dice with the specified number of sides.",
"parameters": {
"type": "object",
"properties": {
"count": {"type": "integer", "description": "Number of dice to roll.", "default": 1},
"sides": {"type": "integer", "description": "Number of sides per die.", "default": 20},
},
}
}
}
def execute(arguments):
count = max(1, min(arguments.get("count", 1), 1000))
sides = max(2, min(arguments.get("sides", 20), 1000))
rolls = [random.randint(1, sides) for _ in range(count)]
return {"rolls": rolls, "total": sum(rolls)}
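The clamping in `execute` can be illustrated on its own. This is a sketch under assumptions: the helper names `clamp` and `roll` are hypothetical, and the `rng` parameter is added here only so the behaviour can be checked with a seeded generator.

```python
import random


def clamp(value, low, high):
    # Same max/min pattern as the tool: force value into [low, high].
    return max(low, min(value, high))


def roll(count=1, sides=20, rng=random):
    count = clamp(count, 1, 1000)  # between 1 and 1000 dice
    sides = clamp(sides, 2, 1000)  # between 2 and 1000 sides
    rolls = [rng.randint(1, sides) for _ in range(count)]
    return {"rolls": rolls, "total": sum(rolls)}


# Out-of-range requests are silently pulled back into range.
result = roll(count=5000, sides=1, rng=random.Random(0))
print(len(result["rolls"]), max(result["rolls"]))
```

Note that the pattern assumes integer arguments; a string such as `"3"` would raise a `TypeError` when compared against the integer bounds.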

View file

@ -0,0 +1,27 @@
from modules.web_search import perform_web_search
tool = {
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web using DuckDuckGo and return a list of result titles and URLs. Use fetch_webpage to read the contents of a specific result.",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "The search query."},
},
"required": ["query"]
}
}
}
def execute(arguments):
query = arguments.get("query", "")
results = perform_web_search(query, num_pages=None, fetch_content=False)
output = []
for r in results:
if r:
output.append({"title": r["title"], "url": r["url"]})
return output if output else [{"error": "No results found."}]
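Each tool file added in this commit exposes a `tool` schema dict and an `execute(arguments)` function. The diff shown here does not include the code that loads and calls them, so the following is only a sketch of how such modules could be dispatched by name; `build_registry`, `dispatch`, and the `echo` stand-in module are all hypothetical.

```python
import json
from types import SimpleNamespace


def build_registry(modules):
    # Key each module by the function name declared in its schema.
    return {m.tool["function"]["name"]: m for m in modules}


def dispatch(registry, name, raw_arguments):
    module = registry.get(name)
    if module is None:
        return {"error": f"Unknown tool: {name}"}
    # Tool-call arguments often arrive as a JSON string; accept dicts as well.
    if isinstance(raw_arguments, str):
        arguments = json.loads(raw_arguments)
    else:
        arguments = raw_arguments
    return module.execute(arguments or {})


# Stand-in for a real tool module, using the same two attributes.
echo = SimpleNamespace(
    tool={
        "type": "function",
        "function": {
            "name": "echo",
            "description": "Return the arguments unchanged.",
            "parameters": {"type": "object", "properties": {}},
        },
    },
    execute=lambda arguments: {"echo": arguments},
)

registry = build_registry([echo])
print(dispatch(registry, "echo", '{"a": 1}'))
print(dispatch(registry, "missing", "{}"))
```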