Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-03-20 04:14:38 +01:00)
Commit 88a318894c
@@ -148,11 +148,11 @@ jobs:
          # 6. Create archive
          cd ..
          if [[ "$RUNNER_OS" == "Windows" ]]; then
            ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm.zip"
            ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.zip"
            echo "Creating archive: $ARCHIVE_NAME"
            powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
          else
            ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm.tar.gz"
            ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.tar.gz"
            echo "Creating archive: $ARCHIVE_NAME"
            tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
          fi
README.md (157 changed lines)
@@ -13,7 +13,7 @@

# Text Generation Web UI

A Gradio web UI for running Large Language Models locally. 100% private, offline, and free.
A Gradio web UI for running Large Language Models locally. 100% private and offline. Supports text generation, vision, tool-calling, training, image generation, and more.

[Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason)
@@ -23,22 +23,21 @@ A Gradio web UI for running Large Language Models locally. 100% private, offline

## Features

- Supports multiple local text generation backends, including [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) (the latter via its own [Dockerfile](https://github.com/oobabooga/text-generation-webui/blob/main/docker/TensorRT-LLM/Dockerfile)).
- Easy setup: Choose between **portable builds** (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or the one-click installer that creates a self-contained `installer_files` directory.
- **Multiple backends**: [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Switch between backends and models without restarting.
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
- **Vision (multimodal)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
- **Tool-calling**: Models can call custom functions during chat — web search, page fetching, math, and more. Each tool is a single `.py` file, easy to create and extend ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Tool-Calling-Tutorial)).
- **OpenAI-compatible API**: Chat and Completions endpoints with tool-calling support. Use as a local drop-in replacement for the OpenAI API ([examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples)).
- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Easy setup**: [Portable builds](https://github.com/oobabooga/text-generation-webui/releases) (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or a one-click installer for the full feature set.
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
- `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters. Prompts are automatically formatted with Jinja2 templates.
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Easy to use, good defaults, and supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)).
- Edit messages, navigate between message versions, and branch conversations at any point.
- Switch between different models in the UI without restarting.
- Free-form text generation in the Notebook tab without being limited to chat turns.
- Multiple sampling parameters and generation options for sophisticated text generation control.
- Aesthetic UI with dark and light themes.
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
- OpenAI-compatible API with Chat and Completions endpoints, including tool-calling support – see [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples).
- Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.

## How to install

@@ -47,10 +46,11 @@ A Gradio web UI for running Large Language Models locally. 100% private, offline

No installation needed – just download, unzip and run. All dependencies included.

Compatible with GGUF (llama.cpp) models on Windows, Linux, and macOS. [Check what models fit your hardware](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator).

Download from here: **https://github.com/oobabooga/text-generation-webui/releases**

- Builds are provided for Linux, Windows, and macOS, with options for CUDA, Vulkan, ROCm, and CPU-only.
- Compatible with GGUF (llama.cpp) models.

#### Option 2: Manual portable install with venv

Very fast setup that should work on any Python 3.9+:
@@ -81,7 +81,7 @@ deactivate

#### Option 3: One-click installer

For users who need additional backends (ExLlamaV3, Transformers) or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch.
For users who need additional backends (ExLlamaV3, Transformers), training, image generation, or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch.

1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it.
2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`.
@@ -146,7 +146,7 @@ conda activate textgen
|--------|---------|---------|
| Linux/WSL | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
| Linux/WSL | CPU only | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu` |
| Linux | AMD | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/rocm6.4` |
| Linux | AMD | `pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp313-cp313-linux_x86_64.whl` |
| MacOS + MPS | Any | `pip3 install torch==2.9.1` |
| Windows | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
| Windows | CPU only | `pip3 install torch==2.9.1` |
@@ -201,7 +201,7 @@ ln -s docker/{nvidia/Dockerfile,nvidia/docker-compose.yml,.dockerignore} .
For AMD GPU:
ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} .
For Intel GPU:
ln -s docker/{intel/Dockerfile,amd/docker-compose.yml,.dockerignore} .
ln -s docker/{intel/Dockerfile,intel/docker-compose.yml,.dockerignore} .
For CPU only
ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} .
cp docker/.env.example .env
@@ -236,20 +236,24 @@ List of command-line flags
</summary>

```txt
usage: server.py [-h] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS]
usage: server.py [-h] [--user-data-dir USER_DATA_DIR] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS]
[--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--image-model IMAGE_MODEL] [--image-model-dir IMAGE_MODEL_DIR] [--image-dtype {bfloat16,float16}]
[--image-attn-backend {flash_attention_2,sdpa}] [--image-cpu-offload] [--image-compile] [--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}]
[--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT] [--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT]
[--ctx-size-draft CTX_SIZE_DRAFT] [--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}] [--spec-ngram-size-n SPEC_NGRAM_SIZE_N]
[--spec-ngram-size-m SPEC_NGRAM_SIZE_M] [--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS] [--gpu-layers N] [--cpu-moe] [--mmproj MMPROJ] [--streaming-llm] [--tensor-split TENSOR_SPLIT]
[--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--ubatch-size UBATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa]
[--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16] [--no-cache] [--trust-remote-code]
[--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE] [--quant_type QUANT_TYPE]
[--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--cpp-runner]
[--alpha_value ALPHA_VALUE] [--rope_freq_base ROPE_FREQ_BASE] [--compress_pos_emb COMPRESS_POS_EMB] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share]
[--parallel PARALLEL] [--fit-target FIT_TARGET] [--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16]
[--no-cache] [--trust-remote-code] [--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE]
[--quant_type QUANT_TYPE] [--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share]
[--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors]
[--portable] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4]
[--nowebui]
[--nowebui] [--temperature N] [--dynatemp-low N] [--dynatemp-high N] [--dynatemp-exponent N] [--smoothing-factor N] [--smoothing-curve N] [--min-p N] [--top-p N] [--top-k N]
[--typical-p N] [--xtc-threshold N] [--xtc-probability N] [--epsilon-cutoff N] [--eta-cutoff N] [--tfs N] [--top-a N] [--top-n-sigma N] [--adaptive-target N] [--adaptive-decay N]
[--dry-multiplier N] [--dry-allowed-length N] [--dry-base N] [--repetition-penalty N] [--frequency-penalty N] [--presence-penalty N] [--encoder-repetition-penalty N]
[--no-repeat-ngram-size N] [--repetition-penalty-range N] [--penalty-alpha N] [--guidance-scale N] [--mirostat-mode N] [--mirostat-tau N] [--mirostat-eta N]
[--do-sample | --no-do-sample] [--dynamic-temperature | --no-dynamic-temperature] [--temperature-last | --no-temperature-last] [--sampler-priority N] [--dry-sequence-breakers N]
[--enable-thinking | --no-enable-thinking] [--reasoning-effort N] [--chat-template-file CHAT_TEMPLATE_FILE]

Text Generation Web UI

@@ -257,7 +261,8 @@ options:
-h, --help show this help message and exit

Basic settings:
--multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.
--user-data-dir USER_DATA_DIR Path to the user data directory. Default: auto-detected.
--multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.
--model MODEL Name of the model to load by default.
--lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.
--model-dir MODEL_DIR Path to directory with all the models.
@@ -280,12 +285,12 @@ Image model:
Quantization method for image model.

Model loader:
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3,
TensorRT-LLM.
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT-
LLM.

Context and cache:
--ctx-size N, --n_ctx N, --max_seq_len N Context size in tokens. llama.cpp: 0 = auto if gpu-layers is also -1.
--cache-type N, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).
--ctx-size, --n_ctx, --max_seq_len N Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.
--cache-type, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).

Speculative decoding:
--model-draft MODEL_DRAFT Path to the draft model for speculative decoding.
@@ -300,7 +305,7 @@ Speculative decoding:
--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS Minimum n-gram hits for ngram-map speculative decoding.

llama.cpp:
--gpu-layers N, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto.
--gpu-layers, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto.
--cpu-moe Move the experts to the CPU (for MoE models).
--mmproj MMPROJ Path to the mmproj file for vision models.
--streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.
@@ -314,13 +319,17 @@ llama.cpp:
--threads THREADS Number of threads to use.
--threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing.
--numa Activate NUMA task allocation for llama.cpp.
--parallel PARALLEL Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set
ctx_size to 32768.
--fit-target FIT_TARGET Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.
Default: 1024.
--extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"

Transformers/Accelerate:
--cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow.
--cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading.
--disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.
--disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to. Defaults to "user_data/cache".
--disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to.
--load-in-8bit Load the model with 8-bit precision (using bitsandbytes).
--bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
--no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.
@@ -341,14 +350,6 @@ ExLlamaV3:
--tp-backend TP_BACKEND The backend for tensor parallelism. Valid options: native, nccl. Default: native.
--cfg-cache Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.

TensorRT-LLM:
--cpp-runner Use the ModelRunnerCpp runner, which is faster than the default ModelRunner.

RoPE:
--alpha_value ALPHA_VALUE Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.
--rope_freq_base ROPE_FREQ_BASE If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).
--compress_pos_emb COMPRESS_POS_EMB Positional embeddings compression factor. Should be set to (context length) / (model's original context length). Equal to 1/rope_freq_scale.

Gradio:
--listen Make the web UI reachable from your local network.
--listen-port LISTEN_PORT The listening port that the server will use.
@@ -365,7 +366,7 @@ Gradio:

API:
--api Enable the API extension.
--public-api Create a public URL for the API using Cloudfare.
--public-api Create a public URL for the API using Cloudflare.
--public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.
--api-port API_PORT The listening port for the API.
--api-key API_KEY API authentication key.
@@ -373,28 +374,67 @@ API:
--api-enable-ipv6 Enable IPv6 for the API
--api-disable-ipv4 Disable IPv4 for the API
--nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode.

API generation defaults:
--temperature N Temperature
--dynatemp-low N Dynamic temperature low
--dynatemp-high N Dynamic temperature high
--dynatemp-exponent N Dynamic temperature exponent
--smoothing-factor N Smoothing factor
--smoothing-curve N Smoothing curve
--min-p N Min P
--top-p N Top P
--top-k N Top K
--typical-p N Typical P
--xtc-threshold N XTC threshold
--xtc-probability N XTC probability
--epsilon-cutoff N Epsilon cutoff
--eta-cutoff N Eta cutoff
--tfs N TFS
--top-a N Top A
--top-n-sigma N Top N Sigma
--adaptive-target N Adaptive target
--adaptive-decay N Adaptive decay
--dry-multiplier N DRY multiplier
--dry-allowed-length N DRY allowed length
--dry-base N DRY base
--repetition-penalty N Repetition penalty
--frequency-penalty N Frequency penalty
--presence-penalty N Presence penalty
--encoder-repetition-penalty N Encoder repetition penalty
--no-repeat-ngram-size N No repeat ngram size
--repetition-penalty-range N Repetition penalty range
--penalty-alpha N Penalty alpha
--guidance-scale N Guidance scale
--mirostat-mode N Mirostat mode
--mirostat-tau N Mirostat tau
--mirostat-eta N Mirostat eta
--do-sample, --no-do-sample Do sample
--dynamic-temperature, --no-dynamic-temperature Dynamic temperature
--temperature-last, --no-temperature-last Temperature last
--sampler-priority N Sampler priority
--dry-sequence-breakers N DRY sequence breakers
--enable-thinking, --no-enable-thinking Enable thinking
--reasoning-effort N Reasoning effort
--chat-template-file CHAT_TEMPLATE_FILE Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model's
built-in template.
```

</details>
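
The RoPE and `--parallel` relations quoted in the flag help above can be sanity-checked numerically. This is an illustrative sketch only; the function names are ours, not code from `server.py`:

```python
# Relations from the --alpha_value, --compress_pos_emb, and --parallel help text.

def rope_freq_base(alpha_value: float) -> float:
    """NTK RoPE scaling: rope_freq_base = 10000 * alpha_value ^ (64 / 63)."""
    return 10000 * alpha_value ** (64 / 63)

def compress_pos_emb(ctx_len: int, original_ctx_len: int) -> float:
    """(context length) / (model's original context length), i.e. 1 / rope_freq_scale."""
    return ctx_len / original_ctx_len

def ctx_per_slot(ctx_size: int, parallel: int) -> int:
    """--parallel divides the total context size equally among request slots."""
    return ctx_size // parallel

print(rope_freq_base(1.0))           # 10000.0: alpha_value = 1 leaves the base unchanged
print(compress_pos_emb(8192, 4096))  # 2.0, so rope_freq_scale = 0.5
print(ctx_per_slot(32768, 4))        # 8192, the example from the --parallel help
```
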
## Downloading models

Models should be placed in the folder `text-generation-webui/user_data/models`. They are usually downloaded from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf).
1. Download a GGUF model file from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf).
2. Place it in the `user_data/models` folder.

To check if a GGUF model will fit in your hardware before downloading it, you can use this tool I created:
That's it. The UI will detect it automatically.

[Accurate GGUF VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator)
To check what will fit your GPU, you can use the [VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator).

* GGUF models are a single file and should be placed directly into `user_data/models`. Example:
<details>
<summary>Other model types (Transformers, EXL3)</summary>

```
text-generation-webui
└── user_data
    └── models
        └── llama-2-13b-chat.Q4_K_M.gguf
```

* The remaining model types (like 16-bit Transformers models and EXL3 models) are made of several files and must be placed in a subfolder. Example:
Models that consist of multiple files (like 16-bit Transformers models and EXL3 models) should be placed in a subfolder inside `user_data/models`:

```
text-generation-webui
@@ -404,31 +444,18 @@ text-generation-webui
├── config.json
├── generation_config.json
├── model-00001-of-00004.safetensors
├── model-00002-of-00004.safetensors
├── model-00003-of-00004.safetensors
├── model-00004-of-00004.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── ...
├── tokenizer_config.json
└── tokenizer.json
```

In both cases, you can use the "Model" tab of the UI to download the model from Hugging Face automatically. It is also possible to download it via the command-line with:

```
python download-model.py organization/model
```

Run `python download-model.py --help` to see all the options.
These formats require the one-click installer (not the portable build).
</details>

## Documentation

https://github.com/oobabooga/text-generation-webui/wiki

## Google Colab notebook

https://colab.research.google.com/github/oobabooga/text-generation-webui/blob/main/Colab-TextGen-GPU.ipynb

## Community

https://www.reddit.com/r/Oobabooga/
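
The README changes above advertise an OpenAI-compatible Chat Completions endpoint. As a minimal illustration of a request body (the `/v1/chat/completions` path follows the OpenAI convention, and port 5000 is assumed as the default `--api-port`, so adjust if yours differs):

```python
import json

# Sketch of a Chat Completions request against the local OpenAI-compatible API.
# The endpoint path and default port are assumptions based on the README's API
# section; no network call is made here.
url = "http://127.0.0.1:5000/v1/chat/completions"

payload = {
    "messages": [
        {"role": "user", "content": "Hello! Reply in one short sentence."}
    ],
    "max_tokens": 64,
    "temperature": 0.7,
}

body = json.dumps(payload)
print(url)
print(body)
# Send it with any HTTP client, for example:
#   curl http://127.0.0.1:5000/v1/chat/completions \
#     -H "Content-Type: application/json" -d "$BODY"
```
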
@@ -21,6 +21,7 @@ set INSTALL_ENV_DIR=%cd%\installer_files\env
set PYTHONNOUSERSITE=1
set PYTHONPATH=
set PYTHONHOME=
set PYTHONUTF8=1
set "CUDA_PATH=%INSTALL_ENV_DIR%"
set "CUDA_HOME=%CUDA_PATH%"
@@ -2,6 +2,7 @@
    display: grid;
    align-items: start;
    grid-template-columns: 60px minmax(0, 1fr);
    width: min(100%, calc(724px + 60px));
    padding-bottom: 22px;
    padding-top: 6px;
    font-size: 18px;

@@ -91,9 +92,6 @@
}

.message-body p {
    margin-bottom: 0 !important;
    font-size: 16px !important;
    line-height: 1.5 !important;
    color: #e0e0e0 !important; /* Light color for text */
}

@@ -122,7 +120,7 @@
}

.message-body p {
    font-size: 14px !important; /* Smaller text for mobile */
    font-size: 14px !important;
}

.username {
@@ -4,6 +4,7 @@
    display: grid;
    align-items: start;
    grid-template-columns: 60px minmax(0, 1fr);
    width: min(100%, calc(724px + 60px + 90px));
    padding-bottom: 21px;
    padding-top: 7px;
    font-size: 18px;

@@ -86,10 +87,8 @@
    border-radius: 20px;
}

.message-body p {
    margin-bottom: 0 !important;
.message-body p, .message-body li {
    font-size: 18px !important;
    line-height: 1.428571429 !important;
    color: rgb(243 244 246) !important;
    text-shadow: 2px 2px 2px rgb(0 0 0);
    font-weight: 500;

@@ -127,7 +126,7 @@
    padding-left: 0;
}

.message-body p {
.message-body p, .message-body li {
    font-size: 16px !important;
}
@@ -19,4 +19,5 @@
    padding-bottom: 1.5em;
    padding-top: 0.5em;
    grid-template-columns: 70px minmax(0, 1fr);
    width: min(100%, calc(724px + 70px));
}
@@ -2,6 +2,7 @@
    display: grid;
    align-items: start;
    grid-template-columns: 60px minmax(0, 1fr);
    width: min(100%, calc(724px + 60px));
    padding-bottom: 1.5em;
    padding-top: 0.5em;
    font-size: 15px;

@@ -46,16 +47,10 @@
    border-radius: 20px;
}

.message-body p {
    font-size: 15px !important;
    line-height: 22.5px !important;
.message-body p, .message-body li {
    font-weight: 500;
}

.message-body p, .chat .message-body ul, .chat .message-body ol {
    margin-bottom: 10px !important;
}

.dark .message-body p em {
    color: rgb(138 138 138) !important;
}
@@ -1,4 +1,5 @@
.message {
    width: min(100%, calc(724px + 60px));
    padding-bottom: 22px;
    padding-top: 3px;
    font-size: 15px;

@@ -60,8 +61,10 @@
    text-align: right;
}

.dark .circle-bot + .text div, .dark .circle-bot + .text * {
    color: #000;
.dark .circle-bot + .text div, .dark .circle-bot + .text *,
.dark .chat .message .circle-bot + .text .message-body :is(h1, h2, h3, h4, h5, h6),
.dark .chat .message .circle-bot + .text .message-body a {
    color: #000 !important;
}

.text {

@@ -76,19 +79,14 @@
    font-weight: bold;
}

.message-body {
}

.message-body img {
    max-width: 300px;
    max-height: 300px;
    border-radius: 20px;
}

.message-body p {
    margin-bottom: 0 !important;
.message-body p, .message-body li {
    font-size: 15px !important;
    line-height: 1.428571429 !important;
    font-weight: 500;
}
@@ -1,5 +1,6 @@
.message {
    display: block;
    width: min(100%, 724px);
    padding-top: 0;
    padding-bottom: 21px;
    font-size: 15px;

@@ -77,14 +78,8 @@
    border-radius: 12px;
}

.message-body p {
.message-body p, .message-body li {
    font-size: 15px !important;
    line-height: 1.4 !important;
    font-weight: 400;
}

.message-body p:first-child {
    margin-top: 0 !important;
}

.dark .message-body p em {

@@ -100,6 +95,3 @@
    margin-top: 8px;
}

.message-body p, .chat .message-body ul, .chat .message-body ol {
    margin-bottom: 10px !important;
}
@@ -78,7 +78,7 @@

.chat .user-message .text,
.chat .assistant-message .text {
    max-width: 700px;
    max-width: 724px;
    margin-left: auto;
    margin-right: auto;
}
css/main.css (165 changed lines)
@ -400,7 +400,6 @@ audio {
|
|||
}
|
||||
|
||||
.chat .message {
|
||||
width: min(100%, 48rem);
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
text-align: start;
|
||||
|
|
@ -431,10 +430,19 @@ audio {
|
|||
font-size: 16px;
|
||||
}
|
||||
|
||||
.dark .message-body :is(h1, h2, h3, h4, h5, h6) {
|
||||
.dark .message-body h1,
|
||||
.dark .message-body h2,
|
||||
.dark .message-body h3,
|
||||
.dark .message-body h4,
|
||||
.dark .message-body h5,
|
||||
.dark .message-body h6 {
|
||||
color: white !important;
|
||||
}
|
||||
|
||||
.dark .message-body blockquote {
|
||||
border-left-color: rgb(255 255 255 / 30%);
|
||||
}
|
||||
|
||||
.message-body h1 {
|
||||
font-weight: 800;
|
||||
font-size: 2.25em;
|
||||
|
|
@ -831,9 +839,20 @@ audio {
|
|||
}
|
||||
}
|
||||
|
||||
.message-body ol, .message-body ul {
|
||||
.message-body p, .message-body li {
|
||||
line-height: 1.75 !important;
|
||||
}
|
||||
|
||||
.message-body p, .message-body ul, .message-body ol {
|
||||
margin: 1.25em 0 !important;
|
||||
}
|
||||
|
||||
.message-body :is(p, ul, ol):first-child {
|
||||
margin-top: 0 !important;
|
||||
  margin-bottom: 1.25em !important;
}

.message-body :is(p, ul, ol):last-child {
  margin-bottom: 0 !important;
}

/* ----------------------------------------------

@@ -1003,6 +1022,49 @@ audio {
  padding-right: 0.5rem;
}

#new-chat-wrapper {
  display: contents;
}

.new-chat-arrow {
  cursor: pointer;
  position: relative;
  padding: 0;
  margin-right: -15px;
  height: 39.594px;
  display: flex;
  align-items: center;
}

.new-chat-menu {
  display: none;
  position: absolute;
  top: 0;
  left: 0;
  padding-top: 1.2em;
  z-index: var(--layer-top);
  white-space: nowrap;
}

.new-chat-arrow:hover .new-chat-menu {
  display: block;
}

.new-chat-menu-item {
  cursor: pointer;
  padding: var(--size-2);
  background: var(--background-fill-primary);
  box-shadow: var(--shadow-drop-lg);
  border-radius: var(--container-radius);
  color: var(--body-text-color);
  font-size: var(--text-md);
  font-weight: var(--button-large-text-weight);
}

.new-chat-menu-item:hover {
  background: var(--background-fill-secondary);
}

#past-chats-row,
#chat-controls {
  width: 260px;

@@ -1373,7 +1435,6 @@ audio {
  overflow-wrap: break-word;
  max-height: 250px;
  overflow-y: scroll;
  contain: layout;
}

.chat .message-body .thinking-content p,

@@ -1662,7 +1723,7 @@ button:focus {
.chat-parent {
  /* Optimize for scrolling performance */
  will-change: scroll-position;
  contain: layout style paint;
  contain: style paint;

  /* Ensure GPU acceleration */
  transform: translateZ(0);

@@ -1802,6 +1863,15 @@ table {
  border-collapse: collapse;
}

.table-wrapper {
  overflow-x: auto;
}

.message-body :is(td, th) {
  word-break: normal;
  overflow-wrap: normal;
}

table, tr, td, th, thead {
  border: 0;
}
@@ -1814,3 +1884,86 @@ tr + tr th { border-top: 1px solid; }

thead + tbody tr:first-child td,
thead + tbody tr:first-child th { border-top: 1px solid; }

/* ------------------------------------------------
  Tools CheckboxGroup - vertical DragDrop-like style
------------------------------------------------ */

/* "Refresh list" link in the Tools label */
.tools-refresh-link {
  cursor: pointer;
}

/* Checkbox list container */
#tools-group {
  padding: 0 !important;
  border-width: 0 !important;
  background: transparent !important;
  min-height: 0 !important;
}

#tools-group .wrap {
  display: flex;
  flex-direction: column;
  flex-wrap: nowrap;
  gap: 4px;
  padding: 0;
  margin-top: var(--spacing-lg);
  max-height: 350px;
  overflow-y: auto;
}

/* Pretty scrollbar for the tools list */
#tools-group .wrap::-webkit-scrollbar {
  width: 8px;
  height: 8px;
}

#tools-group .wrap::-webkit-scrollbar-track {
  background: transparent;
}

#tools-group .wrap::-webkit-scrollbar-thumb,
#tools-group .wrap::-webkit-scrollbar-thumb:hover {
  background: var(--neutral-300);
  border-radius: 30px;
}

.dark #tools-group .wrap::-webkit-scrollbar-thumb,
.dark #tools-group .wrap::-webkit-scrollbar-thumb:hover {
  background: rgb(255 255 255 / 6.25%);
  border-radius: 10px;
}

#tools-group .wrap::-webkit-scrollbar-corner {
  background: transparent;
}

/* Each checkbox item */
#tools-group label {
  display: flex;
  align-items: center;
  gap: 8px;
  padding: 5px 8px;
  border-radius: var(--radius-sm, 4px);
  background: var(--block-background-fill);
  border: 1px solid var(--border-color-primary);
  color: var(--body-text-color);
  font-size: var(--input-text-size);
  font-weight: var(--input-text-weight);
  cursor: pointer;
  user-select: none;
  transition: border-color 0.15s ease, background 0.15s ease;
  box-shadow: none;
}

#tools-group label:hover {
  border-color: var(--input-border-color-focus);
}

#tools-group label span {
  flex: 1;
  overflow: hidden;
  text-overflow: ellipsis;
  white-space: nowrap;
}
@@ -41,9 +41,6 @@ Options:

* **cpu_memory**: Maximum CPU memory in GiB to use for CPU offloading via the accelerate library. Whatever doesn't fit in the GPU or CPU will go to a disk cache if the "disk" checkbox is enabled.
* **compute_dtype**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
* **quant_type**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
* **alpha_value**: Used to extend the context length of a model with a minor loss in quality. I have measured 1.75 to be optimal for 1.5x context and 2.5 for 2x context. That is, with alpha = 2.5 you can take a model with a 4096 context length to an 8192 context length.
* **rope_freq_base**: Originally another way to write "alpha_value", it ended up becoming a necessary parameter for some models like CodeLlama, which was fine-tuned with this set to 1000000 and hence needs to be loaded with it set to 1000000 as well.
* **compress_pos_emb**: The first and original context-length extension method, discovered by [kaiokendev](https://kaiokendev.github.io/til). When set to 2, the context length is doubled; when set to 3, it's tripled, and so on. It should only be used for models that have been fine-tuned with this parameter set to a value other than 1. For models that have not been tuned for a greater context length, alpha_value leads to a smaller accuracy loss.
* **attn_implementation**: Choose the attention implementation. Valid options: `sdpa`, `eager`, `flash_attention_2`. The default (`sdpa`) works well in most cases; `flash_attention_2` may be useful for training.
* **cpu**: Loads the model in CPU mode using PyTorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with transformers is older than llama.cpp and it works, but it's a lot slower. Note: this parameter has a different interpretation in the llama.cpp loader (see above).
* **load_in_8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training rather than inference, so load_in_8bit is slower than load_in_4bit (but more accurate).
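The relationship between alpha_value and rope_freq_base can be sketched with the NTK-aware RoPE scaling rule. This is an illustration under assumptions: the base of 10000 and head dimension of 128 are typical but model-dependent, and the exact exponent varies between implementations.

```python
def alpha_to_rope_freq_base(alpha: float, base: float = 10000.0, head_dim: int = 128) -> float:
    # NTK-aware scaling: freq_base = base * alpha^(d / (d - 2)),
    # where d is the attention head dimension.
    return base * alpha ** (head_dim / (head_dim - 2))


print(alpha_to_rope_freq_base(1.0))  # 10000.0 (alpha = 1 means no scaling)
print(alpha_to_rope_freq_base(2.5) > alpha_to_rope_freq_base(1.75))  # larger alpha, larger freq base
```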
@@ -4,7 +4,7 @@ A LoRA is tied to a specific model architecture — a LoRA trained on Llama 3 8B

### Quick Start

1. Load your base model (no LoRAs loaded).
1. Load your base model with the **Transformers** loader (no LoRAs loaded).
2. Open the **Training** tab > **Train LoRA**.
3. Pick a dataset and configure parameters (see [below](#parameters)).
4. Click **Start LoRA Training** and monitor the [loss](#loss).
|
|
|||
docs/Tool Calling Tutorial.md (new file, 159 lines)

@@ -0,0 +1,159 @@
## Supported models

The following models are supported:

- Qwen 3.5
- GPT-OSS
- Mistral Small / Devstral
- DeepSeek V3
- Kimi-K2
- MiniMax-M2.5
- GLM-5
- Llama 4

Other models that output tool calls as JSON (inside XML tags, code blocks, or plain JSON) are also supported through a generic fallback parser.
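As an illustration of the kind of output the generic fallback is meant to handle, here is a model reply that wraps a JSON tool call in XML-style tags. The regex below is a simplified stand-in for illustration, not the webui's actual parser, and the `<tool_call>` tag name is just one common convention.

```python
import json
import re

# A hypothetical raw model reply containing a JSON tool call inside XML tags.
raw = 'Let me check.\n<tool_call>\n{"name": "get_weather", "arguments": {"location": "Paris"}}\n</tool_call>'

# Pull out the JSON payload between the tags and parse it.
match = re.search(r"<tool_call>\s*(\{.*\})\s*</tool_call>", raw, re.DOTALL)
call = json.loads(match.group(1))
print(call["name"], call["arguments"])  # get_weather {'location': 'Paris'}
```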
## Tool calling in the UI

### 1. Load a model with tool-calling support

Load a model with tool-calling support from the Model tab.

### 2. Select tools

In the chat sidebar, check the tools you want the model to use:

- **web_search** -- Search the web using DuckDuckGo.
- **fetch_webpage** -- Fetch the content of a URL.
- **calculate** -- Evaluate math expressions.
- **get_datetime** -- Get the current date and time.
- **roll_dice** -- Roll dice.

### 3. Chat

Send a message as usual. When the model decides it needs a tool, it will call it automatically. You will see each tool call and its result in a collapsible accordion inside the chat message.

The model may call multiple tools in sequence before giving its final answer.

## Writing custom tools

Each tool is a single `.py` file in `user_data/tools/`. It needs two things:

1. A `tool` dictionary that describes the function (name, description, parameters).
2. An `execute(arguments)` function that runs it and returns the result.

Here is a minimal example (`user_data/tools/get_datetime.py`):

```python
from datetime import datetime

tool = {
    "type": "function",
    "function": {
        "name": "get_datetime",
        "description": "Get the current date and time.",
        "parameters": {
            "type": "object",
            "properties": {},
        }
    }
}


def execute(arguments):
    now = datetime.now()
    return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")}
```

An example with parameters (`user_data/tools/roll_dice.py`):

```python
import random

tool = {
    "type": "function",
    "function": {
        "name": "roll_dice",
        "description": "Roll one or more dice with the specified number of sides.",
        "parameters": {
            "type": "object",
            "properties": {
                "count": {"type": "integer", "description": "Number of dice to roll.", "default": 1},
                "sides": {"type": "integer", "description": "Number of sides per die.", "default": 20},
            },
        }
    }
}


def execute(arguments):
    count = max(1, min(arguments.get("count", 1), 1000))
    sides = max(2, min(arguments.get("sides", 20), 1000))
    rolls = [random.randint(1, sides) for _ in range(count)]
    return {"rolls": rolls, "total": sum(rolls)}
```

You can open the built-in tools in `user_data/tools/` for more examples.
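Before dropping a new file into `user_data/tools/`, you can sanity-check its two required pieces with a small helper. This `validate_tool` function is a hypothetical convenience for your own testing, not part of the webui; it is shown here exercising the `get_datetime` example from above.

```python
from datetime import datetime


def validate_tool(tool, execute):
    # Check the structure the webui expects: a "function"-type tool dict
    # with a name, description, and JSON-schema parameters, plus execute().
    fn = tool.get("function", {})
    assert tool.get("type") == "function", "tool['type'] must be 'function'"
    assert isinstance(fn.get("name"), str) and fn["name"], "function needs a name"
    assert isinstance(fn.get("description"), str), "function needs a description"
    assert fn.get("parameters", {}).get("type") == "object", "parameters must be a JSON-schema object"
    assert callable(execute), "module must define execute(arguments)"
    return True


# The get_datetime tool from above:
tool = {
    "type": "function",
    "function": {
        "name": "get_datetime",
        "description": "Get the current date and time.",
        "parameters": {"type": "object", "properties": {}},
    }
}


def execute(arguments):
    now = datetime.now()
    return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")}


print(validate_tool(tool, execute))  # True
```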
## Tool calling over the API

Tool calling over the API follows the [OpenAI API](https://platform.openai.com/docs/guides/function-calling) convention. Define your tools, send them with your messages, and handle tool calls in a loop until the model gives a final answer.

```python
import json
import requests

url = "http://127.0.0.1:5000/v1/chat/completions"

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a given location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"},
                },
                "required": ["location"]
            }
        }
    }
]


def execute_tool(name, arguments):
    if name == "get_weather":
        return {"temperature": "14°C", "condition": "partly cloudy"}
    return {"error": f"Unknown tool: {name}"}


messages = [{"role": "user", "content": "What's the weather like in Paris?"}]

for _ in range(10):
    response = requests.post(url, json={"messages": messages, "tools": tools}).json()
    choice = response["choices"][0]

    if choice["finish_reason"] == "tool_calls":
        messages.append({
            "role": "assistant",
            "content": choice["message"]["content"],
            "tool_calls": choice["message"]["tool_calls"],
        })

        for tool_call in choice["message"]["tool_calls"]:
            name = tool_call["function"]["name"]
            arguments = json.loads(tool_call["function"]["arguments"])
            result = execute_tool(name, arguments)
            print(f"Tool call: {name}({arguments}) => {result}")

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": json.dumps(result),
            })
    else:
        print(f"\nAssistant: {choice['message']['content']}")
        break
```
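You can also steer tool use with the OpenAI-style `tool_choice` field: a dict naming a function forces that tool, while the string `"none"` disables tool detection entirely. This sketch only constructs the request payload; no request is sent, and the `get_weather` tool is the hypothetical one from the example above.

```python
import json

# Payload that forces the model to call get_weather rather than letting it decide.
payload = {
    "messages": [{"role": "user", "content": "What's the weather like in Paris?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a given location.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string", "description": "City name"}},
                "required": ["location"],
            },
        },
    }],
    "tool_choice": {"type": "function", "function": {"name": "get_weather"}},
}

# Use "tool_choice": "none" instead to disable tool detection entirely.
print(json.dumps(payload["tool_choice"]))
```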
@@ -11,8 +11,10 @@ from pydantic import ValidationError

from extensions.openai.errors import InvalidRequestError
from extensions.openai.typing import ToolDefinition
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
from extensions.openai.utils import debug_msg
from modules.tool_parsing import get_tool_call_id, parse_tool_call, detect_tool_call_format
from modules import shared
from modules.reasoning import extract_reasoning
from modules.chat import (
    generate_chat_prompt,
    generate_chat_reply,

@@ -37,17 +39,114 @@ def load_chat_template_file(filepath):
    return text


def convert_logprobs_to_tiktoken(model, logprobs):
    # more problems than it's worth.
    # try:
    #     encoder = tiktoken.encoding_for_model(model)
    #     # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
    #     return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
    # except KeyError:
    #     # assume native tokens if we can't find the tokenizer
    #     return logprobs
def _get_raw_logprob_entries(offset=0):
    """Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.

    return logprobs
    Returns (new_entries, new_offset).
    """
    if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
        return [], offset

    all_entries = shared.model.last_completion_probabilities
    new_entries = all_entries[offset:]
    return new_entries, len(all_entries)


def _dict_to_logprob_entries(token_dict):
    """Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format."""
    if not token_dict:
        return []

    return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}]


def _parse_entry_top(entry):
    """Extract the top logprobs list from a raw entry, handling both key names."""
    return entry.get('top_logprobs', entry.get('top_probs', []))


def format_chat_logprobs(entries):
    """Format logprob entries into OpenAI chat completions logprobs format.

    Output: {"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}]}
    """
    if not entries:
        return None

    content = []
    for entry in entries:
        top = _parse_entry_top(entry)
        if not top:
            continue

        chosen = top[0]
        token_str = chosen.get('token', '')
        token_logprob = chosen.get('logprob', chosen.get('prob', 0))

        top_list = []
        for item in top:
            t = item.get('token', '')
            lp = item.get('logprob', item.get('prob', 0))
            top_list.append({
                "token": t,
                "logprob": lp,
                "bytes": list(t.encode('utf-8')) if t else None
            })

        content.append({
            "token": token_str,
            "logprob": token_logprob,
            "bytes": list(token_str.encode('utf-8')) if token_str else None,
            "top_logprobs": top_list
        })

    return {"content": content, "refusal": None} if content else None


def format_completion_logprobs(entries):
    """Format logprob entries into OpenAI completions logprobs format.

    Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"}
    """
    if not entries:
        return None

    tokens = []
    token_logprobs = []
    top_logprobs = []
    text_offset = []
    offset = 0

    for entry in entries:
        top = _parse_entry_top(entry)
        if not top:
            continue

        chosen = top[0]
        token_str = chosen.get('token', '')
        token_logprob = chosen.get('logprob', chosen.get('prob', 0))

        tokens.append(token_str)
        token_logprobs.append(token_logprob)
        text_offset.append(offset)
        offset += len(token_str)

        top_dict = {}
        for item in top:
            t = item.get('token', '')
            lp = item.get('logprob', item.get('prob', 0))
            top_dict[t] = lp
        top_logprobs.append(top_dict)

    if not tokens:
        return None

    return {
        "tokens": tokens,
        "token_logprobs": token_logprobs,
        "top_logprobs": top_logprobs,
        "text_offset": text_offset
    }


def process_parameters(body, is_legacy=False):
@@ -72,7 +171,16 @@ def process_parameters(body, is_legacy=False):
    elif isinstance(body['stop'], list):
        generate_params['custom_stopping_strings'] = body['stop']

    if shared.args.loader != 'llama.cpp':
    # Resolve logprobs: for chat completions, logprobs is a bool and the count
    # comes from top_logprobs. Normalize to an int for all backends.
    logprobs = body.get('logprobs', None)
    top_logprobs = body.get('top_logprobs', None)
    if logprobs is True:
        logprobs = max(top_logprobs, 1) if top_logprobs is not None else 5
    generate_params['logprobs'] = logprobs

    # For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
    if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
        from transformers import LogitsProcessorList

        from modules.transformers_loader import (

@@ -85,13 +193,9 @@ def process_parameters(body, is_legacy=False):
        if logit_bias:  # {str: float, ...}
            logits_processor = [LogitsBiasProcessor(logit_bias)]

        logprobs = None  # coming to chat eventually
        if 'logprobs' in body:
            logprobs = body.get('logprobs', 0)  # maybe cap at topk? don't clamp 0-5.
        if logprobs is not None and logprobs > 0:
            generate_params['logprob_proc'] = LogprobProcessor(logprobs)
            logits_processor.extend([generate_params['logprob_proc']])
        else:
            logprobs = None

        if logits_processor:  # requires logits_processor support
            generate_params['logits_processor'] = LogitsProcessorList(logits_processor)

@@ -137,12 +241,14 @@ def convert_history(history):
    user_input = ""
    user_input_last = True
    system_message = ""
    seen_non_system = False

    for entry in history:
        content = entry["content"]
        role = entry["role"]

        if role == "user":
            seen_non_system = True
            # Extract text content (images handled by model-specific code)
            content = process_multimodal_content(content)
            user_input = content

@@ -154,6 +260,7 @@ def convert_history(history):

            current_message = content
        elif role == "assistant":
            seen_non_system = True
            meta = {}
            tool_calls = entry.get("tool_calls")
            if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:

@@ -170,13 +277,22 @@ def convert_history(history):
            else:
                chat_dialogue.append(['', current_reply, '', meta])
        elif role == "tool":
            seen_non_system = True
            user_input_last = False
            meta = {}
            if "tool_call_id" in entry:
                meta["tool_call_id"] = entry["tool_call_id"]
            chat_dialogue.append(['', '', content, meta])
        elif role == "system":
            system_message += f"\n{content}" if system_message else content
        elif role in ("system", "developer"):
            if not seen_non_system:
                # Leading system messages go to custom_system_message (placed at top)
                system_message += f"\n{content}" if system_message else content
            else:
                # Mid-conversation system messages: preserve position in history
                if current_message:
                    chat_dialogue.append([current_message, '', '', {}])
                    current_message = ""
                chat_dialogue.append([content, '', '', {"role": "system"}])

    if not user_input_last:
        user_input = ""

@@ -202,6 +318,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
    if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
        tools = validateTools(body['tools'])  # raises InvalidRequestError if validation fails

    tool_choice = body.get('tool_choice', None)
    if tool_choice == "none":
        tools = None  # Disable tool detection entirely

    messages = body['messages']
    for m in messages:
        if 'role' not in m:
@@ -293,29 +413,55 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p

    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)
    if logprob_proc:
        logprob_proc.token_alternatives_history.clear()
    chat_logprobs_offset = [0]  # mutable for closure access in streaming

    def chat_streaming_chunk(content, chunk_tool_calls=None):
    def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False, reasoning_content=None):
        # begin streaming
        delta = {}
        if include_role:
            delta['role'] = 'assistant'
            delta['refusal'] = None
        if content is not None:
            delta['content'] = content
        if reasoning_content is not None:
            delta['reasoning_content'] = reasoning_content
        if chunk_tool_calls:
            delta['tool_calls'] = chunk_tool_calls

        chunk = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            "system_fingerprint": None,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                "delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls},
                "delta": delta,
                "logprobs": None,
            }],
        }

        if logprob_proc:  # not official for chat yet
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
        # else:
        #     chunk[resp_list][0]["logprobs"] = None
        if logprob_proc:
            entries = _dict_to_logprob_entries(logprob_proc.token_alternatives)
            formatted = format_chat_logprobs(entries)
            if formatted:
                chunk[resp_list][0]["logprobs"] = formatted
        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
            entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0])
            if entries:
                formatted = format_chat_logprobs(entries)
                if formatted:
                    chunk[resp_list][0]["logprobs"] = formatted

        return chunk

    # Check if usage should be included in streaming chunks per OpenAI spec
    stream_options = body.get('stream_options')
    include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))

    # generate reply #######################################
    if prompt_only:
        prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
@@ -323,75 +469,133 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
        return

    if stream:
        yield chat_streaming_chunk('')
        chunk = chat_streaming_chunk('', include_role=True)
        if include_usage:
            chunk['usage'] = None
        yield chunk

    generator = generate_chat_reply(
        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)

    answer = ''
    seen_content = ''
    seen_reasoning = ''

    tool_calls = []
    end_last_tool_call = 0
    supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
    _tool_parsers = None

    # Filter supported_tools when tool_choice specifies a particular function
    if supported_tools and isinstance(tool_choice, dict):
        specified_func = tool_choice.get("function", {}).get("name")
        if specified_func and specified_func in supported_tools:
            supported_tools = [specified_func]

    if supported_tools is not None:
        _template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '')
        _tool_parsers, _, _ = detect_tool_call_format(_template_str)

    for a in generator:
        answer = a['internal'][-1][1]

        if supported_tools is not None:
            tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
            tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else []
            if len(tool_call) > 0:
                for tc in tool_call:
                    tc["id"] = getToolCallId()
                    tc["index"] = len(tool_calls)
                    tc["id"] = get_tool_call_id()
                    if stream:
                        tc["index"] = len(tool_calls)
                    tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
                    tool_calls.append(tc)
                end_last_tool_call = len(answer)

        if stream:
            len_seen = len(seen_content)
            new_content = answer[len_seen:]

            if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                continue

            chunk = chat_streaming_chunk(new_content)

            seen_content = answer
            yield chunk

        # stop generation if tool_calls were generated previously
        # Stop generation before streaming content if tool_calls were detected,
        # so that raw tool markup is not sent as content deltas.
        if len(tool_calls) > 0:
            break

        if stream:
            # Strip reasoning/thinking blocks so only final content is streamed.
            # Reasoning is emitted separately as reasoning_content deltas.
            reasoning, content = extract_reasoning(answer)
            if reasoning is not None:
                new_reasoning = reasoning[len(seen_reasoning):]
                new_content = content[len(seen_content):]
            else:
                new_reasoning = None
                new_content = answer[len(seen_content):]

            if (not new_content and not new_reasoning) or chr(0xfffd) in (new_content or '') + (new_reasoning or ''):
                continue

            chunk = chat_streaming_chunk(
                content=new_content if new_content else None,
                reasoning_content=new_reasoning if new_reasoning else None,
            )
            if include_usage:
                chunk['usage'] = None

            if reasoning is not None:
                seen_reasoning = reasoning
                seen_content = content
            else:
                seen_content = answer
            yield chunk

    token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
    completion_token_count = len(encode(answer)[0])
    stop_reason = "stop"
    if len(tool_calls) > 0:
        stop_reason = "tool_calls"
    if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
    elif token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
        stop_reason = "length"
    else:
        stop_reason = "stop"

    if stream:
        chunk = chat_streaming_chunk('', tool_calls)
        chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls)
        chunk[resp_list][0]['finish_reason'] = stop_reason
        chunk['usage'] = {
        usage = {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }

        yield chunk
        if include_usage:
            chunk['usage'] = None
            yield chunk
            # Separate usage-only chunk with choices: [] per OpenAI spec
            yield {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": shared.model_name,
                "system_fingerprint": None,
                resp_list: [],
                "usage": usage
            }
        else:
            yield chunk
    else:
        reasoning, content = extract_reasoning(answer)
        message = {
            "role": "assistant",
            "refusal": None,
            "content": None if tool_calls else content,
            **({"reasoning_content": reasoning} if reasoning else {}),
            **({"tool_calls": tool_calls} if tool_calls else {}),
        }
        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            "system_fingerprint": None,
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
                "message": {"role": "assistant", "content": answer, **({"tool_calls": tool_calls} if tool_calls else {})},
                "message": message,
                "logprobs": None,
            }],
            "usage": {
                "prompt_tokens": token_count,
@@ -399,11 +603,19 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
                "total_tokens": token_count + completion_token_count
            }
        }
        if logprob_proc:  # not official for chat yet
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
        # else:
        #     resp[resp_list][0]["logprobs"] = None
        if logprob_proc:
            all_entries = []
            for alt in logprob_proc.token_alternatives_history:
                all_entries.extend(_dict_to_logprob_entries(alt))
            formatted = format_chat_logprobs(all_entries)
            if formatted:
                resp[resp_list][0]["logprobs"] = formatted
        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
            raw = getattr(shared.model, 'last_completion_probabilities', None)
            if raw:
                formatted = format_chat_logprobs(raw)
                if formatted:
                    resp[resp_list][0]["logprobs"] = formatted

        yield resp

@@ -411,7 +623,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
    object_type = 'text_completion'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    cmpl_id = "cmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    prompt_str = 'context' if is_legacy else 'prompt'

@@ -445,6 +657,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
    generate_params['stop_event'] = stop_event
    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)
    if logprob_proc:
        logprob_proc.token_alternatives_history.clear()
    suffix = body['suffix'] if body['suffix'] else ''
    echo = body['echo']

@@ -456,6 +670,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
        logger.info(f"Found {len(raw_images)} image(s) in request.")
        generate_params['raw_images'] = raw_images

    n_completions = body.get('n', 1) or 1

    if not stream:
        prompt_arg = body[prompt_str]

@@ -469,6 +685,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
        resp_list_data = []
        total_completion_token_count = 0
        total_prompt_token_count = 0
        choice_index = 0

        for idx, prompt in enumerate(prompt_arg, start=0):
            if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
@ -483,37 +700,59 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
prompt = decode(prompt)[0]

prefix = prompt if echo else ''

# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''

for a in generator:
answer = a

token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"

respi = {
"index": idx,
"finish_reason": stop_reason,
"text": prefix + answer + suffix,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}
original_seed = generate_params.get('seed', -1)
for _n in range(n_completions):
# Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
if original_seed >= 0:
generate_params['seed'] = original_seed + _n

resp_list_data.extend([respi])
if logprob_proc:
logprob_proc.token_alternatives_history.clear()

# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''

for a in generator:
answer = a

completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"

if logprob_proc:
all_entries = []
for alt in logprob_proc.token_alternatives_history:
all_entries.extend(_dict_to_logprob_entries(alt))
completion_logprobs = format_completion_logprobs(all_entries)
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
raw = getattr(shared.model, 'last_completion_probabilities', None)
completion_logprobs = format_completion_logprobs(raw)
else:
completion_logprobs = None

respi = {
"index": choice_index,
"finish_reason": stop_reason,
"text": prefix + answer + suffix,
"logprobs": completion_logprobs,
}

resp_list_data.append(respi)
choice_index += 1

resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
"system_fingerprint": None,
resp_list: resp_list_data,
"usage": {
"prompt_tokens": total_prompt_token_count,
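The per-choice seed bump added in this hunk can be sketched in isolation (the function name and shape below are illustrative, not from the diff): a fixed seed is incremented once per completion so each of the `n` choices samples differently, while the "random seed" sentinel `-1` passes through unchanged.

```python
def seeds_for_choices(original_seed: int, n_completions: int) -> list:
    """Hypothetical sketch of the seed rule above: a non-negative seed
    is incremented per completion; -1 (random) is left untouched."""
    seeds = []
    for i in range(n_completions):
        # -1 means "pick a random seed": forward it as-is for every choice
        seeds.append(original_seed + i if original_seed >= 0 else original_seed)
    return seeds
```

For example, `seeds_for_choices(42, 3)` yields `[42, 43, 44]`, matching the llama.cpp native behavior the comment in the diff refers to.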
@ -538,24 +777,41 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
prefix = prompt if echo else ''
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
# Check if usage should be included in streaming chunks per OpenAI spec
|
||||
stream_options = body.get('stream_options')
|
||||
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
|
||||
cmpl_logprobs_offset = [0] # mutable for closure access in streaming
|
||||
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
if logprob_proc:
|
||||
chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives))
|
||||
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
|
||||
entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0])
|
||||
chunk_logprobs = format_completion_logprobs(entries) if entries else None
|
||||
else:
|
||||
chunk_logprobs = None
|
||||
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"text": content,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
"logprobs": chunk_logprobs,
|
||||
}],
|
||||
}
|
||||
|
||||
return chunk
|
||||
|
||||
yield text_streaming_chunk(prefix)
|
||||
chunk = text_streaming_chunk(prefix)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
|
|
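The `include_usage` resolution added above accepts either a raw dict (an unvalidated JSON body) or a parsed model object carrying an `include_usage` attribute. A standalone sketch of that check (name is illustrative):

```python
def resolve_include_usage(stream_options) -> bool:
    """Mirror of the stream_options check above: handles a plain dict,
    an object with an include_usage attribute, or None, defaulting to False."""
    if not stream_options:
        return False
    if isinstance(stream_options, dict):
        return bool(stream_options.get('include_usage'))
    return bool(getattr(stream_options, 'include_usage', False))
```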
@ -575,6 +831,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
|
||||
seen_content = answer
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
|
|
@ -584,13 +842,27 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
|
||||
chunk = text_streaming_chunk(suffix)
|
||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||
chunk["usage"] = {
|
||||
usage = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
yield chunk
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
# Separate usage-only chunk with choices: [] per OpenAI spec
|
||||
yield {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: [],
|
||||
"usage": usage
|
||||
}
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
|
||||
def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
|
||||
|
|
|
|||
|
|
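The stream tail above changes shape depending on `include_usage`: without it, the final content chunk is the last event; with it, every chunk carries `"usage": None` and a separate usage-only chunk with an empty choices list follows, per the OpenAI spec. A minimal sketch of that branching (names are illustrative):

```python
def final_stream_chunks(base_chunk: dict, usage: dict, include_usage: bool):
    """Sketch of the stream tail above: with include_usage, the last
    content chunk gets usage=None and a usage-only chunk with empty
    choices is appended; without it, only the content chunk is sent."""
    chunks = []
    if include_usage:
        base_chunk = dict(base_chunk, usage=None)  # copy, don't mutate caller's dict
        chunks.append(base_chunk)
        chunks.append({"id": base_chunk["id"], "choices": [], "usage": usage})
    else:
        chunks.append(base_chunk)
    return chunks
```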
@ -1,4 +1,4 @@
from modules import shared, ui
from modules import loaders, shared
from modules.logging_colors import logger
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model

@ -20,10 +20,14 @@ def list_models():

def list_models_openai_format():
"""Returns model list in OpenAI API format"""
model_names = get_available_models()
if shared.model_name and shared.model_name != 'None':
data = [model_info_dict(shared.model_name)]
else:
data = []

return {
"object": "list",
"data": [model_info_dict(name) for name in model_names]
"data": data
}


@ -50,7 +54,7 @@ def _load_model(data):
# parameters exposed in the UI. Never allow security-sensitive
# flags like trust_remote_code or extra_flags to be set via the API.
blocked_keys = {'extra_flags'}
allowed_keys = set(ui.list_model_elements()) - blocked_keys
allowed_keys = set(loaders.list_model_elements()) - blocked_keys
if args:
for k in args:
if k in allowed_keys and hasattr(shared.args, k):
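The `list_models_openai_format` change above switches `/v1/models` from listing every model on disk to reporting only the currently loaded one. A simplified sketch (the entry fields below are placeholders; the real code delegates to `model_info_dict`):

```python
def list_models_openai_format(loaded_model):
    """Sketch of the new behavior above: report only the loaded model,
    or an empty list when nothing is loaded ('None' is the UI sentinel)."""
    if loaded_model and loaded_model != 'None':
        # Field set here is an assumption for illustration only
        data = [{"id": loaded_model, "object": "model", "owned_by": "user"}]
    else:
        data = []
    return {"object": "list", "data": data}
```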
@ -21,6 +21,7 @@ import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.errors import OpenAIError
from extensions.openai.utils import _start_cloudflared
from modules import shared
from modules.logging_colors import logger

@ -94,6 +95,20 @@ app.add_middleware(
)


@app.exception_handler(OpenAIError)
async def openai_error_handler(request: Request, exc: OpenAIError):
error_type = "server_error" if exc.code >= 500 else "invalid_request_error"
return JSONResponse(
status_code=exc.code,
content={"error": {
"message": exc.message,
"type": error_type,
"param": getattr(exc, 'param', None),
"code": None
}}
)


@app.middleware("http")
async def validate_host_header(request: Request, call_next):
# Be strict about only approving access to localhost by default

@ -119,6 +134,12 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
is_legacy = "/generate" in path

if request_data.stream:
if (request_data.n or 1) > 1:
return JSONResponse(
status_code=400,
content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
)

stop_event = threading.Event()

async def generator():

@ -130,6 +151,8 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
break

yield {"data": json.dumps(resp)}

yield {"data": "[DONE]"}
finally:
stop_event.set()
response.close()

@ -170,6 +193,8 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
break

yield {"data": json.dumps(resp)}

yield {"data": "[DONE]"}
finally:
stop_event.set()
response.close()

@ -433,10 +458,13 @@ def run_server():

# In the server configuration:
server_addrs = []
if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
server_addrs.append('[::]' if shared.args.listen else '[::1]')
if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1')
if shared.args.listen and shared.args.listen_host:
server_addrs.append(shared.args.listen_host)
else:
if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
server_addrs.append('[::]' if shared.args.listen else '[::1]')
if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1')

if not server_addrs:
raise Exception('you MUST enable IPv6 or IPv4 for the API to work')

@ -447,11 +475,11 @@ def run_server():
port,
shared.args.public_api_id,
max_attempts=3,
on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}\n')
on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}/v1\n')
)
else:
url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
urls = [f'{url_proto}{addr}:{port}' for addr in server_addrs]
urls = [f'{url_proto}{addr}:{port}/v1' for addr in server_addrs]
if len(urls) > 1:
logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
else:
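The `run_server` change above makes an explicit `--listen-host` take precedence over the IPv6/IPv4 wildcard selection. The decision tree can be sketched as a pure function (names and parameters are illustrative):

```python
def pick_server_addrs(listen: bool, listen_host, enable_ipv6: bool, disable_ipv4: bool) -> list:
    """Sketch of the bind-address logic above: an explicit listen host
    wins outright; otherwise wildcard (listening) or loopback addresses
    are chosen per address family based on the flags."""
    addrs = []
    if listen and listen_host:
        addrs.append(listen_host)
    else:
        if enable_ipv6:
            addrs.append('[::]' if listen else '[::1]')
        if not disable_ipv4:
            addrs.append('0.0.0.0' if listen else '127.0.0.1')
    if not addrs:
        raise Exception('you MUST enable IPv6 or IPv4 for the API to work')
    return addrs
```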
@ -99,6 +99,10 @@ class ToolCall(BaseModel):
function: FunctionCall


class StreamOptions(BaseModel):
include_usage: bool | None = False


class CompletionRequestParams(BaseModel):
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")

@ -109,10 +113,11 @@ class CompletionRequestParams(BaseModel):
logit_bias: dict | None = None
logprobs: int | None = None
max_tokens: int | None = 512
n: int | None = Field(default=1, description="Unused parameter.")
n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.")
presence_penalty: float | None = shared.args.presence_penalty
stop: str | List[str] | None = None
stream: bool | None = False
stream_options: StreamOptions | None = None
suffix: str | None = None
temperature: float | None = shared.args.temperature
top_p: float | None = shared.args.top_p

@ -145,16 +150,27 @@ class ChatCompletionRequestParams(BaseModel):
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
logit_bias: dict | None = None
logprobs: bool | None = None
top_logprobs: int | None = None
max_tokens: int | None = None
max_completion_tokens: int | None = None
n: int | None = Field(default=1, description="Unused parameter.")
presence_penalty: float | None = shared.args.presence_penalty
stop: str | List[str] | None = None
stream: bool | None = False
stream_options: StreamOptions | None = None
temperature: float | None = shared.args.temperature
top_p: float | None = shared.args.top_p
user: str | None = Field(default=None, description="Unused parameter.")

@model_validator(mode='after')
def resolve_max_tokens(self):
if self.max_tokens is None and self.max_completion_tokens is not None:
self.max_tokens = self.max_completion_tokens
return self

mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")

instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
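The `resolve_max_tokens` validator above folds the newer OpenAI field `max_completion_tokens` into `max_tokens` when only the former is set, so downstream code keeps reading a single field. The rule as a plain function, without the pydantic machinery:

```python
def resolve_max_tokens(max_tokens, max_completion_tokens):
    """Plain-function sketch of the validator above: max_completion_tokens
    acts as a fallback for max_tokens; an explicit max_tokens always wins."""
    if max_tokens is None and max_completion_tokens is not None:
        return max_completion_tokens
    return max_tokens
```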
@ -1,8 +1,5 @@
import base64
import json
import os
import random
import re
import time
import traceback
from typing import Callable, Optional

@ -55,473 +52,3 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
time.sleep(3)

raise Exception('Could not start cloudflared.')


def getToolCallId() -> str:
letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b = [random.choice(letter_bytes) for _ in range(8)]
return "call_" + "".join(b).lower()


def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
# check if property 'function' exists and is a dictionary, otherwise adapt dict
if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
candidate_dict = {"type": "function", "function": candidate_dict}
if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
candidate_dict['name'] = candidate_dict['function']
del candidate_dict['function']
candidate_dict = {"type": "function", "function": candidate_dict}
if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
# check if 'name' exists within 'function' and is part of known tools
if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
candidate_dict["type"] = "function"  # ensure required property 'type' exists and has the right value
# map property 'parameters' used by some older models to 'arguments'
if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
del candidate_dict["function"]["parameters"]
return candidate_dict
return None


def _extractBalancedJson(text: str, start: int) -> str | None:
"""Extract a balanced JSON object from text starting at the given position.

Walks through the string tracking brace depth and string boundaries
to correctly handle arbitrary nesting levels.
"""
if start >= len(text) or text[start] != '{':
return None
depth = 0
in_string = False
escape_next = False
for i in range(start, len(text)):
c = text[i]
if escape_next:
escape_next = False
continue
if c == '\\' and in_string:
escape_next = True
continue
if c == '"':
in_string = not in_string
continue
if in_string:
continue
if c == '{':
depth += 1
elif c == '}':
depth -= 1
if depth == 0:
return text[start:i + 1]
return None

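The brace-depth walk in `_extractBalancedJson` above can be exercised standalone. This is the same algorithm, reproduced as a self-contained sketch: braces inside JSON strings (and escaped quotes inside those strings) are ignored, so nested objects and tricky values parse correctly.

```python
def extract_balanced_json(text: str, start: int):
    # Walk the string tracking brace depth; quote and backslash handling
    # makes braces inside string values invisible to the depth counter.
    if start >= len(text) or text[start] != '{':
        return None
    depth, in_string, escape_next = 0, False, False
    for i in range(start, len(text)):
        c = text[i]
        if escape_next:
            escape_next = False
            continue
        if c == '\\' and in_string:
            escape_next = True
            continue
        if c == '"':
            in_string = not in_string
            continue
        if in_string:
            continue
        if c == '{':
            depth += 1
        elif c == '}':
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None
```

For example, `extract_balanced_json('call {"a": {"b": "}"}} tail', 5)` returns `'{"a": {"b": "}"}}'`: the `}` inside the string value does not terminate the object.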
def _parseChannelToolCalls(answer: str, tool_names: list[str]):
"""Parse channel-based tool calls used by GPT-OSS and similar models.

Format:
<|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
"""
matches = []
for m in re.finditer(
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches


def _parseBareNameToolCalls(answer: str, tool_names: list[str]):
"""Parse bare function-name style tool calls used by Mistral and similar models.

Format:
functionName{"arg": "value"}
Multiple calls are concatenated directly or separated by whitespace.
"""
matches = []
# Match tool name followed by opening brace, then extract balanced JSON
escaped_names = [re.escape(name) for name in tool_names]
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
for match in re.finditer(pattern, answer):
text = match.group(0)
name = None
for n in tool_names:
if text.startswith(n):
name = n
break
if not name:
continue
brace_start = match.end() - 1
json_str = _extractBalancedJson(answer, brace_start)
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches


def _parseXmlParamToolCalls(answer: str, tool_names: list[str]):
"""Parse XML-parameter style tool calls used by Qwen3.5 and similar models.

Format:
<tool_call>
<function=function_name>
<parameter=param_name>value</parameter>
</function>
</tool_call>
"""
matches = []
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
func_match = re.search(r'<function=([^>]+)>', tc_content)
if not func_match:
continue
func_name = func_match.group(1).strip()
if func_name not in tool_names:
continue
arguments = {}
for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
param_name = param_match.group(1).strip()
param_value = param_match.group(2).strip()
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass  # keep as string
arguments[param_name] = param_value
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches

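The XML-parameter format handled by `_parseXmlParamToolCalls` above can be demonstrated with a self-contained sketch (same regexes and JSON coercion as the docstring describes; tool-name filtering omitted for brevity):

```python
import json
import re

def parse_xml_param_call(answer: str):
    """Standalone sketch of the Qwen-style XML-parameter tool-call format."""
    calls = []
    for tc in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
        body = tc.group(1)
        func = re.search(r'<function=([^>]+)>', body)
        if not func:
            continue
        args = {}
        for p in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', body, re.DOTALL):
            value = p.group(2).strip()
            try:
                value = json.loads(value)  # decode numbers/bools/null
            except (json.JSONDecodeError, ValueError):
                pass  # keep as string
            args[p.group(1).strip()] = value
        calls.append({"name": func.group(1).strip(), "arguments": args})
    return calls
```

Note how `days=3` comes back as the integer `3` while `Paris` stays a string, because each parameter value is first tried through `json.loads`.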
def _parseKimiToolCalls(answer: str, tool_names: list[str]):
"""Parse Kimi-K2-style tool calls using pipe-delimited tokens.

Format:
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
<|tool_calls_section_end|>
"""
matches = []
for m in re.finditer(
r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches


def _parseMiniMaxToolCalls(answer: str, tool_names: list[str]):
"""Parse MiniMax-style tool calls using invoke/parameter XML tags.

Format:
<minimax:tool_call>
<invoke name="function_name">
<parameter name="param_name">value</parameter>
</invoke>
</minimax:tool_call>
"""
matches = []
for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
# Split on <invoke> to handle multiple parallel calls in one block
for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
func_name = invoke_match.group(1).strip()
if func_name not in tool_names:
continue
invoke_body = invoke_match.group(2)
arguments = {}
for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
param_name = param_match.group(1).strip()
param_value = param_match.group(2).strip()
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass  # keep as string
arguments[param_name] = param_value
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches


def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]):
"""Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.

Format:
<|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
"""
matches = []
for m in re.finditer(
r'<\|tool▁call▁begin\|>\s*(\S+?)\s*<\|tool▁sep\|>\s*',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches


def _parseGlmToolCalls(answer: str, tool_names: list[str]):
"""Parse GLM-style tool calls using arg_key/arg_value XML pairs.

Format:
<tool_call>function_name
<arg_key>key1</arg_key>
<arg_value>value1</arg_value>
</tool_call>
"""
matches = []
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
# First non-tag text is the function name
name_match = re.match(r'([^<\s]+)', tc_content.strip())
if not name_match:
continue
func_name = name_match.group(1).strip()
if func_name not in tool_names:
continue
# Extract arg_key/arg_value pairs
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
if len(keys) != len(vals):
continue
arguments = {}
for k, v in zip(keys, vals):
try:
v = json.loads(v)
except (json.JSONDecodeError, ValueError):
pass  # keep as string
arguments[k] = v
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches


def _parsePythonicToolCalls(answer: str, tool_names: list[str]):
"""Parse pythonic-style tool calls used by Llama 4 and similar models.

Format:
[func_name(param1="value1", param2="value2"), func_name2(...)]
"""
matches = []
# Match a bracketed list of function calls
bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
if not bracket_match:
return matches

inner = bracket_match.group(1)

# Build pattern for known tool names
escaped_names = [re.escape(name) for name in tool_names]
name_pattern = '|'.join(escaped_names)

for call_match in re.finditer(
r'(' + name_pattern + r')\(([^)]*)\)',
inner
):
func_name = call_match.group(1)
params_str = call_match.group(2).strip()
arguments = {}

if params_str:
# Parse key="value" pairs, handling commas inside quoted values
for param_match in re.finditer(
r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
params_str
):
param_name = param_match.group(1)
param_value = param_match.group(2).strip()
# Strip surrounding quotes
if (param_value.startswith('"') and param_value.endswith('"')) or \
(param_value.startswith("'") and param_value.endswith("'")):
param_value = param_value[1:-1]
# Try to parse as JSON for numeric/bool/null values
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass
arguments[param_name] = param_value

matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})

return matches

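The key=value regex in `_parsePythonicToolCalls` above is worth demonstrating on its own: the quoted-string alternatives let commas survive inside quoted values, quotes are stripped afterwards, and bare tokens are run through `json.loads` so numbers and booleans come back typed. A self-contained sketch:

```python
import json
import re

def parse_pythonic_args(params_str: str) -> dict:
    """Standalone sketch of the argument regex above: splits key="value"
    pairs while respecting commas inside quoted values."""
    args = {}
    for m in re.finditer(r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)', params_str):
        value = m.group(2).strip()
        # Strip surrounding quotes
        if (value.startswith('"') and value.endswith('"')) or \
           (value.startswith("'") and value.endswith("'")):
            value = value[1:-1]
        # Bare numbers/bools/null decode via JSON; everything else stays a string
        try:
            value = json.loads(value)
        except (json.JSONDecodeError, ValueError):
            pass
        args[m.group(1)] = value
    return args
```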
def parseToolCall(answer: str, tool_names: list[str]):
matches = []

# abort on very short answers to save computation cycles
if len(answer) < 10:
return matches

# Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
matches = _parseDeepSeekToolCalls(answer, tool_names)
if matches:
return matches

# Check for Kimi-K2-style tool calls (pipe-delimited tokens)
matches = _parseKimiToolCalls(answer, tool_names)
if matches:
return matches

# Check for channel-based tool calls (e.g. GPT-OSS format)
matches = _parseChannelToolCalls(answer, tool_names)
if matches:
return matches

# Check for MiniMax-style tool calls (invoke/parameter XML tags)
matches = _parseMiniMaxToolCalls(answer, tool_names)
if matches:
return matches

# Check for GLM-style tool calls (arg_key/arg_value XML pairs)
matches = _parseGlmToolCalls(answer, tool_names)
if matches:
return matches

# Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
matches = _parseXmlParamToolCalls(answer, tool_names)
if matches:
return matches

# Check for bare function-name style tool calls (e.g. Mistral format)
matches = _parseBareNameToolCalls(answer, tool_names)
if matches:
return matches

# Check for pythonic-style tool calls (e.g. Llama 4 format)
matches = _parsePythonicToolCalls(answer, tool_names)
if matches:
return matches

# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]

for pattern in patterns:
for match in re.finditer(pattern, answer, re.DOTALL):
# print(match.group(2))
if match.group(2) is None:
continue
# remove backtick wraps if present
candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
candidate = re.sub(r"```$", "", candidate.strip())
# unwrap inner tags
candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
if not candidate.strip().startswith("["):
candidate = "[" + candidate + "]"

candidates = []
try:
# parse the candidate JSON into a dictionary
candidates = json.loads(candidate)
if not isinstance(candidates, list):
candidates = [candidates]
except json.JSONDecodeError:
# Ignore invalid JSON silently
continue

for candidate_dict in candidates:
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
if checked_candidate is not None:
matches.append(checked_candidate)

# last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
if len(matches) == 0:
try:
candidate = answer
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
if not candidate.strip().startswith("["):
candidate = "[" + candidate + "]"
# parse the candidate JSON into a dictionary
candidates = json.loads(candidate)
if not isinstance(candidates, list):
candidates = [candidates]
for candidate_dict in candidates:
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
if checked_candidate is not None:
matches.append(checked_candidate)
except json.JSONDecodeError:
# Ignore invalid JSON silently
pass

return matches

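Both the tag-based branch and the last-resort branch of `parseToolCall` above normalize "several JSON objects separated by newlines" into one parseable JSON array before calling `json.loads` once. That trick in isolation:

```python
import json
import re

def normalize_json_objects(candidate: str):
    """Sketch of the normalization above: adjacent JSON objects separated
    by newlines are joined with commas and wrapped in a list so a single
    json.loads call parses all of them."""
    if re.search(r"\}\s*\n\s*\{", candidate) is not None:
        candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
    if not candidate.strip().startswith("["):
        candidate = "[" + candidate + "]"
    parsed = json.loads(candidate)
    return parsed if isinstance(parsed, list) else [parsed]
```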
@@ -269,7 +269,49 @@ function removeLastClick() {
  document.getElementById("Remove-last").click();
}

function autoScrollToBottom() {
  if (!window.isScrolled) {
    const chatParent = document.getElementById("chat")?.parentNode?.parentNode?.parentNode;
    if (chatParent) {
      const maxScroll = chatParent.scrollHeight - chatParent.clientHeight;
      if (maxScroll > 0 && chatParent.scrollTop < maxScroll - 1) {
        chatParent.scrollTop = maxScroll;
      }
    }
  }
}

function updateInstructPadding() {
  const chatElement = document.getElementById("chat");
  if (chatElement && chatElement.getAttribute("data-mode") === "instruct") {
    const messagesContainer = chatElement.querySelector(".messages");
    const lastChild = messagesContainer?.lastElementChild;
    const prevSibling = lastChild?.previousElementSibling;
    if (lastChild && prevSibling && chatElement.offsetHeight > 0) {
      let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight);
      if (window.innerWidth <= 924) {
        bufferHeight = Math.max(0, bufferHeight - 32);
      }
      messagesContainer.style.paddingBottom = `${bufferHeight}px`;
    }
  }
}

let pendingMorphdomData = null;
let morphdomRafId = null;

function handleMorphdomUpdate(data) {
  pendingMorphdomData = data;
  if (!morphdomRafId) {
    morphdomRafId = requestAnimationFrame(() => {
      morphdomRafId = null;
      applyMorphdomUpdate(pendingMorphdomData);
      pendingMorphdomData = null;
    });
  }
}

function applyMorphdomUpdate(data) {
  // Determine target element and use it as query scope
  var target_element, target_html;
  if (data.last_message_only) {
@@ -283,27 +325,21 @@ function handleMorphdomUpdate(data) {
  const queryScope = target_element;

  // Track open blocks
  // Track open blocks and store their scroll positions
  const openBlocks = new Set();
  const scrollPositions = {};
  queryScope.querySelectorAll(".thinking-block").forEach(block => {
    const blockId = block.getAttribute("data-block-id");
    // If block exists and is open, add to open set
    if (blockId && block.hasAttribute("open")) {
      openBlocks.add(blockId);
    }
  });

  // Store scroll positions for any open blocks
  const scrollPositions = {};
  queryScope.querySelectorAll(".thinking-block[open]").forEach(block => {
    const content = block.querySelector(".thinking-content");
    const blockId = block.getAttribute("data-block-id");
    if (content && blockId) {
      const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
      scrollPositions[blockId] = {
        position: content.scrollTop,
        isAtBottom: isAtBottom
      };
      const content = block.querySelector(".thinking-content");
      if (content) {
        const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
        scrollPositions[blockId] = {
          position: content.scrollTop,
          isAtBottom: isAtBottom
        };
      }
    }
  });

@@ -313,8 +349,8 @@ function handleMorphdomUpdate(data) {
  {
    onBeforeElUpdated: function(fromEl, toEl) {
      // Preserve code highlighting
      if (fromEl.tagName === "PRE" && fromEl.querySelector("code[data-highlighted]")) {
        const fromCode = fromEl.querySelector("code");
      if (fromEl.tagName === "PRE") {
        const fromCode = fromEl.querySelector("code[data-highlighted]");
        const toCode = toEl.querySelector("code");

        if (fromCode && toCode && fromCode.textContent === toCode.textContent) {
@@ -359,10 +395,23 @@ function handleMorphdomUpdate(data) {
    }
  );

  // Syntax highlighting and LaTeX
  if (window.doSyntaxHighlighting) {
    window.doSyntaxHighlighting();
  }

  // Auto-scroll runs both before and after padding update.
  // Before: so content growth isn't hidden by padding absorption.
  // After: so padding-added space is also scrolled into view.
  autoScrollToBottom();
  updateInstructPadding();
  autoScrollToBottom();

  // Add toggle listeners for new blocks
  queryScope.querySelectorAll(".thinking-block").forEach(block => {
    if (!block._hasToggleListener) {
      block.addEventListener("toggle", function(e) {
        const wasScrolled = window.isScrolled;
        if (this.open) {
          const content = this.querySelector(".thinking-content");
          if (content) {
@@ -371,6 +420,12 @@ function handleMorphdomUpdate(data) {
          }, 0);
        }
      }
      autoScrollToBottom();
      updateInstructPadding();
      autoScrollToBottom();
      // Restore scroll state so the browser's layout adjustment
      // from the toggle doesn't disable auto-scroll
      window.isScrolled = wasScrolled;
    });
    block._hasToggleListener = true;
  }
211  js/main.js
@@ -2,6 +2,12 @@
// Main
// ------------------------------------------------

// Sync highlight.js theme with the actual Gradio theme
var defined_hljs_css = document.body.classList.contains("dark") ? "file/css/highlightjs/github-dark.min.css" : "file/css/highlightjs/github.min.css";
if (document.getElementById("highlight-css").getAttribute("href") !== defined_hljs_css) {
  document.getElementById("highlight-css").setAttribute("href", defined_hljs_css);
}

let main_parent = document.getElementById("chat-tab").parentNode;
let extensions = document.getElementById("extensions");
@@ -145,10 +151,13 @@ targetElement.classList.add("pretty_scrollbar");
targetElement.classList.add("chat-parent");
window.isScrolled = false;
let scrollTimeout;
let lastScrollTop = 0;
let lastScrollHeight = 0;
let lastClientHeight = 0;

targetElement.addEventListener("scroll", function() {
  let diff = targetElement.scrollHeight - targetElement.clientHeight;
  let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff == 0;
  let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0;

  // Add scrolling class to disable hover effects
  if (window.isScrolled || !isAtBottomNow) {
@@ -157,9 +166,12 @@ targetElement.addEventListener("scroll", function() {

  if(isAtBottomNow) {
    window.isScrolled = false;
  } else {
  } else if (targetElement.scrollTop < lastScrollTop && targetElement.scrollHeight >= lastScrollHeight && targetElement.clientHeight <= lastClientHeight) {
    window.isScrolled = true;
  }
  lastScrollTop = targetElement.scrollTop;
  lastScrollHeight = targetElement.scrollHeight;
  lastClientHeight = targetElement.clientHeight;

  // Clear previous timeout and set new one
  clearTimeout(scrollTimeout);
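The new `else if` condition above only flags a manual scroll-up when the position decreased while the content did not shrink and the viewport did not grow, since either of those layout changes can move `scrollTop` without any user input. The same predicate can be expressed in Python for clarity (a sketch; the function name `user_scrolled_up` is illustrative, not from the repo):

```python
def user_scrolled_up(scroll_top, last_top, scroll_height, last_height,
                     client_height, last_client):
    """Treat the event as a deliberate scroll-up only if the position
    decreased while content height did not shrink and the viewport did
    not grow (both of which also shift scrollTop on their own)."""
    return (scroll_top < last_top
            and scroll_height >= last_height
            and client_height <= last_client)
```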
@@ -170,61 +182,28 @@ targetElement.addEventListener("scroll", function() {
});

// Create a MutationObserver instance
const observer = new MutationObserver(function(mutations) {
  // Check if this is just the scrolling class being toggled
  const isScrollingClassOnly = mutations.every(mutation =>
    mutation.type === "attributes" &&
    mutation.attributeName === "class" &&
    mutation.target === targetElement
  );

const observer = new MutationObserver(function() {
  if (targetElement.classList.contains("_generating")) {
    typing.parentNode.classList.add("visible-dots");
    document.getElementById("stop").style.display = "flex";
    document.getElementById("Generate").style.display = "none";
    // If the user is near the bottom, ensure auto-scroll is enabled
    // for the new reply. This catches cases where isScrolled was
    // incorrectly set to true by layout shifts during page load, etc.
    const diff = targetElement.scrollHeight - targetElement.clientHeight;
    if (Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0) {
      window.isScrolled = false;
    }
  } else {
    typing.parentNode.classList.remove("visible-dots");
    document.getElementById("stop").style.display = "none";
    document.getElementById("Generate").style.display = "flex";
  }

  doSyntaxHighlighting();

  if (!window.isScrolled && !isScrollingClassOnly) {
    const maxScroll = targetElement.scrollHeight - targetElement.clientHeight;
    if (maxScroll > 0 && targetElement.scrollTop < maxScroll - 1) {
      targetElement.scrollTop = maxScroll;
    }
  }

  const chatElement = document.getElementById("chat");
  if (chatElement && chatElement.getAttribute("data-mode") === "instruct") {
    const messagesContainer = chatElement.querySelector(".messages");
    const lastChild = messagesContainer?.lastElementChild;
    const prevSibling = lastChild?.previousElementSibling;
    if (lastChild && prevSibling) {
      // Add padding to the messages container to create room for the last message.
      // The purpose of this is to avoid constant scrolling during streaming in
      // instruct mode.
      let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight);

      // Subtract header height when screen width is <= 924px
      if (window.innerWidth <= 924) {
        bufferHeight = Math.max(0, bufferHeight - 32);
      }

      messagesContainer.style.paddingBottom = `${bufferHeight}px`;
    }
  }
});

// Configure the observer to watch for changes in the subtree and attributes
// Only watch for attribute changes on targetElement (e.g. _generating class)
const config = {
  childList: true,
  subtree: true,
  characterData: true,
  attributeOldValue: true,
  characterDataOldValue: true
  attributes: true
};

// Start observing the target element
@@ -243,64 +222,76 @@ function isElementVisibleOnScreen(element) {
  );
}

function doSyntaxHighlighting() {
window.doSyntaxHighlighting = function() {
  const messageBodies = document.getElementById("chat").querySelectorAll(".message-body");

  if (messageBodies.length > 0) {
    observer.disconnect();
    let hasSeenVisible = false;

    try {
      let hasSeenVisible = false;
      // Go from last message to first
      for (let i = messageBodies.length - 1; i >= 0; i--) {
        const messageBody = messageBodies[i];

    // Go from last message to first
    for (let i = messageBodies.length - 1; i >= 0; i--) {
      const messageBody = messageBodies[i];
        if (isElementVisibleOnScreen(messageBody)) {
          hasSeenVisible = true;

      if (isElementVisibleOnScreen(messageBody)) {
        hasSeenVisible = true;
          // Handle both code and math in a single pass through each message
          const codeBlocks = messageBody.querySelectorAll("pre code:not([data-highlighted])");
          codeBlocks.forEach((codeBlock) => {
            hljs.highlightElement(codeBlock);
            codeBlock.setAttribute("data-highlighted", "true");
            codeBlock.classList.add("pretty_scrollbar");
          });

        // Handle both code and math in a single pass through each message
        const codeBlocks = messageBody.querySelectorAll("pre code:not([data-highlighted])");
        codeBlocks.forEach((codeBlock) => {
          hljs.highlightElement(codeBlock);
          codeBlock.setAttribute("data-highlighted", "true");
          codeBlock.classList.add("pretty_scrollbar");
        });

        // Only render math in visible elements
        const mathContainers = messageBody.querySelectorAll("p, span, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, figcaption, caption, dd, dt");
        mathContainers.forEach(container => {
          if (isElementVisibleOnScreen(container)) {
            renderMathInElement(container, {
              delimiters: [
                { left: "$$", right: "$$", display: true },
                { left: "$", right: "$", display: false },
                { left: "\\(", right: "\\)", display: false },
                { left: "\\[", right: "\\]", display: true },
              ],
            });
          }
        });
      } else if (hasSeenVisible) {
        // We've seen visible messages but this one is not visible
        // Since we're going from last to first, we can break
        break;
      }
          // Only render math in visible elements
          const mathContainers = messageBody.querySelectorAll("p, span, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, figcaption, caption, dd, dt");
          mathContainers.forEach(container => {
            if (isElementVisibleOnScreen(container)) {
              renderMathInElement(container, {
                delimiters: [
                  { left: "$$", right: "$$", display: true },
                  { left: "$", right: "$", display: false },
                  { left: "\\(", right: "\\)", display: false },
                  { left: "\\[", right: "\\]", display: true },
                ],
              });
            }
          });
        } else if (hasSeenVisible) {
          // We've seen visible messages but this one is not visible
          // Since we're going from last to first, we can break
          break;
        }
      }
    } finally {
      observer.observe(targetElement, config);
    }
  }
}
const doSyntaxHighlighting = window.doSyntaxHighlighting;

//------------------------------------------------
// Add some scrollbars
//------------------------------------------------
const textareaElements = document.querySelectorAll(".add_scrollbar textarea");
for(i = 0; i < textareaElements.length; i++) {
  textareaElements[i].classList.remove("scroll-hide");
  textareaElements[i].classList.add("pretty_scrollbar");
  textareaElements[i].style.resize = "none";
const scrollbarElements = document.querySelectorAll(".add_scrollbar textarea, .add_scrollbar .drag-drop-list");
for(i = 0; i < scrollbarElements.length; i++) {
  scrollbarElements[i].classList.remove("scroll-hide");
  scrollbarElements[i].classList.add("pretty_scrollbar");
  scrollbarElements[i].style.resize = "none";
}


//------------------------------------------------
// Tools: inject "Refresh list" link into the label
//------------------------------------------------
const toolsTitle = document.querySelector("#tools-group > [data-testid='block-info']");
const toolsInfo = toolsTitle ? toolsTitle.nextElementSibling : null;
if (toolsInfo) {
  const refreshLink = document.createElement("span");
  refreshLink.textContent = " [Refresh list]";
  refreshLink.className = "tools-refresh-link";
  refreshLink.addEventListener("click", function(e) {
    e.preventDefault();
    document.querySelector("#tools-refresh-btn").click();
  });
  toolsInfo.appendChild(refreshLink);
}

//------------------------------------------------
@@ -561,6 +552,38 @@ document.querySelectorAll(".focus-on-chat-input").forEach(element => {
  });
});

//------------------------------------------------
// "New chat" hover menu with incognito option
//------------------------------------------------

(function() {
  const newChatBtn = document.getElementById("new-chat-btn");

  const wrapper = document.createElement("div");
  wrapper.id = "new-chat-wrapper";
  newChatBtn.replaceWith(wrapper);
  wrapper.appendChild(newChatBtn);

  const arrow = document.createElement("span");
  arrow.className = "new-chat-arrow";
  arrow.textContent = "\u25BE";

  const menu = document.createElement("div");
  menu.className = "new-chat-menu";
  const option = document.createElement("div");
  option.className = "new-chat-menu-item";
  option.textContent = "Incognito chat";
  menu.appendChild(option);

  arrow.appendChild(menu);
  wrapper.appendChild(arrow);

  option.addEventListener("click", function(e) {
    e.stopPropagation();
    document.querySelector("#incognito-chat-btn").click();
  });
})();

//------------------------------------------------
// Fix a border around the "past chats" menu
//------------------------------------------------
@@ -1090,15 +1113,13 @@ document.fonts.addEventListener("loadingdone", (event) => {
    const currentHeight = chatInputRow.offsetHeight;
    const heightDifference = currentHeight - originalHeight;
    chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`;
    if (!window.isScrolled) {
      chatParent.scrollTop = chatParent.scrollHeight - chatParent.clientHeight;
    }
  }

  // Watch for changes that might affect height
  const observer = new MutationObserver(updateMargin);
  observer.observe(chatInputRow, {
    childList: true,
    subtree: true,
    attributes: true
  });
  // Watch for size changes that affect height
  new ResizeObserver(updateMargin).observe(chatInputRow);

  // Also listen for window resize
  window.addEventListener("resize", updateMargin);
605  modules/chat.py
@@ -6,12 +6,13 @@ import json
import pprint
import re
import shutil
import threading
import time
from datetime import datetime
from functools import partial
from pathlib import Path

import gradio as gr
import markupsafe
import yaml
from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment
@@ -23,10 +24,12 @@ from modules.extensions import apply_extensions
from modules.html_generator import (
    chat_html_wrapper,
    convert_to_markdown,
    extract_thinking_block,
    make_thumbnail
)
from modules.image_utils import open_image_safely
from modules.logging_colors import logger
from modules.reasoning import THINKING_FORMATS
from modules.text_generation import (
    generate_reply,
    get_encoded_length,
@@ -41,6 +44,8 @@ from modules.utils import (
)
from modules.web_search import add_web_search_attachments

_history_file_lock = threading.Lock()


def strftime_now(format):
    return datetime.now().strftime(format)
@@ -75,8 +80,22 @@ jinja_env = ImmutableSandboxedEnvironment(
    lstrip_blocks=True,
    extensions=[loopcontrols]
)


def custom_tojson(value, indent=None, ensure_ascii=True):
    return markupsafe.Markup(json.dumps(value, indent=indent, ensure_ascii=ensure_ascii))


jinja_env.filters["tojson"] = custom_tojson
jinja_env.globals["strftime_now"] = strftime_now


def _raise_exception(message):
    raise ValueError(message)


jinja_env.globals["raise_exception"] = _raise_exception

_template_cache = {}
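The `custom_tojson` filter above overrides Jinja2's built-in `tojson` mainly to expose `ensure_ascii` to chat templates; the `markupsafe.Markup` wrapper marks the serialized JSON as safe so it is not re-escaped. A stdlib-only sketch of the serialization behavior (without the `Markup` wrapper, which needs `markupsafe`):

```python
import json

def custom_tojson(value, indent=None, ensure_ascii=True):
    # Sketch without the markupsafe.Markup wrapper; the real filter wraps
    # the result in Markup so Jinja2 treats it as pre-escaped output.
    return json.dumps(value, indent=indent, ensure_ascii=ensure_ascii)

print(custom_tojson({"city": "São Paulo"}))                      # non-ASCII escaped
print(custom_tojson({"city": "São Paulo"}, ensure_ascii=False))  # kept as-is
```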
@@ -150,6 +169,49 @@ def _deserialize_tool_call_arguments(tool_calls):
    return result


def _expand_tool_sequence(tool_seq):
    """Expand a tool_sequence list into API messages.

    Returns a list of dicts (role: assistant with tool_calls, or role: tool).
    If any tool_call IDs are missing a matching tool result, a synthetic
    empty result is inserted so the prompt is never malformed.
    """
    messages = []
    expected_ids = []
    seen_ids = set()

    for item in tool_seq:
        if 'tool_calls' in item:
            deserialized = _deserialize_tool_call_arguments(item['tool_calls'])
            messages.append({
                "role": "assistant",
                "content": item.get('content', ''),
                "tool_calls": deserialized
            })
            for tc in item['tool_calls']:
                tc_id = tc.get('id', '')
                if tc_id:
                    expected_ids.append(tc_id)
        elif item.get('role') == 'tool':
            messages.append({
                "role": "tool",
                "content": item['content'],
                "tool_call_id": item.get('tool_call_id', '')
            })
            seen_ids.add(item.get('tool_call_id', ''))

    # Fill in synthetic results for any orphaned tool call IDs
    for tc_id in expected_ids:
        if tc_id not in seen_ids:
            messages.append({
                "role": "tool",
                "content": "",
                "tool_call_id": tc_id
            })

    return messages


def generate_chat_prompt(user_input, state, **kwargs):
    impersonate = kwargs.get('impersonate', False)
    _continue = kwargs.get('_continue', False)
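The orphan-fill guarantee of `_expand_tool_sequence` can be exercised standalone. Below is a simplified re-implementation for illustration only: `_deserialize_tool_call_arguments` is replaced by passing the calls through unchanged, since its definition is outside this hunk.

```python
def expand_tool_sequence(tool_seq):
    """Simplified sketch: expand a tool_sequence list into API messages,
    appending an empty synthetic tool result for any tool_call id that
    never received a matching result."""
    messages, expected_ids, seen_ids = [], [], set()
    for item in tool_seq:
        if 'tool_calls' in item:
            messages.append({"role": "assistant",
                             "content": item.get('content', ''),
                             "tool_calls": item['tool_calls']})
            expected_ids += [tc['id'] for tc in item['tool_calls'] if tc.get('id')]
        elif item.get('role') == 'tool':
            messages.append({"role": "tool",
                             "content": item['content'],
                             "tool_call_id": item.get('tool_call_id', '')})
            seen_ids.add(item.get('tool_call_id', ''))
    # Orphaned calls get an empty result so the prompt is never malformed
    for tc_id in expected_ids:
        if tc_id not in seen_ids:
            messages.append({"role": "tool", "content": "", "tool_call_id": tc_id})
    return messages
```

Feeding it an assistant turn whose tool call has no result produces a trailing synthetic `role: tool` message with empty content.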
@@ -185,6 +247,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
        name1=state['name1'],
        name2=state['name2'],
        user_bio=replace_character_names(state['user_bio'], state['name1'], state['name2']),
        tools=state['tools'] if 'tools' in state else None,
    )

    messages = []
@@ -296,7 +359,18 @@ def generate_chat_prompt(user_input, state, **kwargs):
            if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant':
                messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls'])

        if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
            # Expand tool_sequence from metadata (inserted AFTER assistant so that
            # the final order is: user → tool_calls → tool_results → final_answer)
            meta_key = f"assistant_{row_idx}"
            tool_seq = metadata.get(meta_key, {}).get('tool_sequence', [])
            if tool_seq:
                for msg in reversed(_expand_tool_sequence(tool_seq)):
                    messages.insert(insert_pos, msg)

        if entry_meta.get('role') == 'system':
            if user_msg:
                messages.insert(insert_pos, {"role": "system", "content": user_msg})
        elif user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
            # Check for user message attachments in metadata
            user_key = f"user_{row_idx}"
            enhanced_user_msg = user_msg
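The `for msg in reversed(...): messages.insert(insert_pos, msg)` idiom above inserts an expanded sequence at a fixed index while preserving its internal order. A minimal demonstration of why the reversal is needed (`insert_before` is an illustrative name, not a repo function):

```python
def insert_before(messages, insert_pos, new_msgs):
    """Insert new_msgs at insert_pos, preserving their order: inserting
    in reverse at the same index keeps each item ahead of the previous one."""
    for msg in reversed(new_msgs):
        messages.insert(insert_pos, msg)
    return messages
```

Without `reversed`, repeated `insert` at the same index would leave the new messages in back-to-front order.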
@@ -365,6 +439,12 @@ def generate_chat_prompt(user_input, state, **kwargs):

        messages.append({"role": "user", "content": user_input})

    # Expand tool_sequence for the current entry (excluded from the
    # history loop during regenerate — needed so the model sees prior
    # tool calls and results when re-generating the final answer).
    current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', [])
    messages.extend(_expand_tool_sequence(current_tool_seq))

    if impersonate and state['mode'] != 'chat-instruct':
        messages.append({"role": "user", "content": "fake user message replace me"})
@@ -810,6 +890,8 @@ def generate_search_query(user_message, state):
        query = query.rsplit("</think>", 1)[1]
    elif "<|start|>assistant<|channel|>final<|message|>" in query:
        query = query.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
    elif "<|channel|>final<|message|>" in query:
        query = query.rsplit("<|channel|>final<|message|>", 1)[1]
    elif "</seed:think>" in query:
        query = query.rsplit("</seed:think>", 1)[1]
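The fallback chain above strips model reasoning from the generated search query by keeping only the text after the last closing reasoning tag, whichever format the model emitted. Extracted into a standalone sketch (the function name `strip_reasoning` is illustrative):

```python
def strip_reasoning(query):
    """Keep only the text after the last closing reasoning tag,
    checking each known tag format in order."""
    for tag in ("</think>",
                "<|start|>assistant<|channel|>final<|message|>",
                "<|channel|>final<|message|>",
                "</seed:think>"):
        if tag in query:
            return query.rsplit(tag, 1)[1]
    return query
```

`rsplit(tag, 1)[1]` takes everything after the *last* occurrence, so nested or repeated thinking blocks are all discarded.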
@@ -884,7 +966,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
        }
    else:
        text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
        if regenerate:
        if regenerate and not state.get('_tool_turn'):
            row_idx = len(output['internal']) - 1

            # Store the old response as a version before regenerating
@@ -947,10 +1029,33 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
    # Add timestamp for assistant's response at the start of generation
    update_message_metadata(output['metadata'], "assistant", row_idx, timestamp=get_current_timestamp(), model_name=shared.model_name)

    # Detect if the template appended a thinking start tag to the prompt
    thinking_prefix = None
    if not _continue:
        stripped_prompt = prompt.rstrip('\n')
        for start_tag, end_tag, content_tag in THINKING_FORMATS:
            if start_tag is not None and stripped_prompt.endswith(start_tag):
                thinking_prefix = start_tag
                break

    # When tools are active, buffer streaming output during potential tool
    # call generation to prevent raw markup from leaking into the display.
    _check_tool_markers = bool(state.get('tools'))
    _last_visible_before_tool_buffer = None
    if _check_tool_markers:
        from modules.tool_parsing import streaming_tool_buffer_check, detect_tool_call_format
        _tool_names = [t['function']['name'] for t in state['tools'] if 'function' in t and 'name' in t['function']]
        _template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
        _, _streaming_markers, _check_bare_names = detect_tool_call_format(_template_str)

    # Generate
    reply = None
    for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True, for_ui=for_ui)):

        # Prepend thinking tag if the template appended it to the prompt
        if thinking_prefix:
            reply = thinking_prefix + reply

        # Extract the reply
        if state['mode'] in ['chat', 'chat-instruct']:
            if not _continue:
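The real `streaming_tool_buffer_check` lives in `modules/tool_parsing` and its exact heuristics are not shown in this diff. As a rough, assumption-laden sketch of the idea (buffer display as soon as the streamed tail could be the start of a tool-call block, including a partially emitted marker), with an illustrative function name:

```python
def looks_like_tool_call_start(text, markers=("<tool_call>",), tool_names=()):
    """Hypothetical simplification: return True if the streamed text
    contains a tool-call marker, ends with a partial marker, or mentions
    a known tool name, so the UI should buffer instead of showing raw markup."""
    for marker in markers:
        if marker in text:
            return True
        # A partially streamed marker at the end also triggers buffering
        for i in range(1, len(marker)):
            if text.endswith(marker[:i]):
                return True
    return any(name in text for name in tool_names)
```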
@@ -982,7 +1087,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
            output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]

            # Keep version metadata in sync during streaming (for regeneration)
            if regenerate:
            if regenerate and not state.get('_tool_turn'):
                row_idx = len(output['internal']) - 1
                key = f"assistant_{row_idx}"
                current_idx = output['metadata'][key]['current_version_index']
@@ -992,25 +1097,35 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
                })

            if is_stream:
                if _check_tool_markers:
                    if streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
                        continue
                    _last_visible_before_tool_buffer = output['visible'][-1][1]

                yield output

    if _continue:
        # Reprocess the entire internal text for extensions (like translation)
        full_internal = output['internal'][-1][1]
        if state['mode'] in ['chat', 'chat-instruct']:
            full_visible = re.sub("(<USER>|<user>|{{user}})", state['name1'], full_internal)
        else:
            full_visible = full_internal
        # Reprocess the entire internal text for extensions (like translation).
        # Skip entirely when the visible text contains <tool_call> markers,
        # since those only exist in visible (internal is cleared after each tool
        # execution) and rebuilding from internal would destroy them. Output
        # extensions also can't handle the raw <tool_call> markup safely.
        if '<tool_call>' not in output['visible'][-1][1]:
            full_internal = output['internal'][-1][1]
            if state['mode'] in ['chat', 'chat-instruct']:
                full_visible = re.sub("(<USER>|<user>|{{user}})", state['name1'], full_internal)
            else:
                full_visible = full_internal

        full_visible = html.escape(full_visible)
        if not state.get('_skip_output_extensions'):
            output['visible'][-1][1] = apply_extensions('output', full_visible, state, is_chat=True)
            full_visible = html.escape(full_visible)
            if not state.get('_skip_output_extensions'):
                output['visible'][-1][1] = apply_extensions('output', full_visible, state, is_chat=True)
        else:
            if not state.get('_skip_output_extensions'):
                output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)

    # Final sync for version metadata (in case streaming was disabled)
    if regenerate:
    if regenerate and not state.get('_tool_turn'):
        row_idx = len(output['internal']) - 1
        key = f"assistant_{row_idx}"
        current_idx = output['metadata'][key]['current_version_index']
@@ -1019,6 +1134,13 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
            'visible_content': output['visible'][row_idx][1]
        })

    # When tool markers were detected during streaming, restore the last
    # visible text from before buffering started so raw markup doesn't flash
    # in the UI. The internal text is left intact so the caller can still
    # parse tool calls from it.
    if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
        output['visible'][-1][1] = _last_visible_before_tool_buffer or ''

    yield output
@@ -1064,7 +1186,11 @@ def character_is_loaded(state, raise_exception=False):

def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
    '''
    Same as above but returns HTML for the UI
    Same as above but returns HTML for the UI.
    When tools are selected, wraps generation in a loop that detects
    tool calls, executes them, and re-generates until the model stops.
    All tool output is consolidated into a single visible chat bubble
    using metadata['assistant_N']['tool_sequence'].
    '''

    if not character_is_loaded(state):
@@ -1079,19 +1205,257 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
        send_dummy_message(text, state)
        send_dummy_reply(state['start_with'], state)

    history = state['history']
    # On regenerate, clear old tool_sequence metadata so it gets rebuilt.
    # Save it first so it can be stored per-version below.
    # This must happen after the start_with logic above, which may remove
    # and re-add messages, changing which row we operate on.
    _old_tool_sequence = None
    if regenerate:
        history = state['history']
        meta = history.get('metadata', {})
        row_idx = len(history['internal']) - 1
        if row_idx >= 0:
            _old_tool_sequence = meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None)

    # Load tools if any are selected
    selected = state.get('selected_tools', [])
    parse_tool_call = None
    _tool_parsers = None
    if selected:
        from modules.tool_use import load_tools, execute_tool
        from modules.tool_parsing import parse_tool_call, get_tool_call_id, detect_tool_call_format

    if selected:
        tool_defs, tool_executors = load_tools(selected)
        state['tools'] = tool_defs
        tool_func_names = [t['function']['name'] for t in tool_defs]
        _template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
        _tool_parsers, _, _ = detect_tool_call_format(_template_str)
    else:
        tool_func_names = None

    visible_prefix = []  # Accumulated tool call summaries + results
    last_save_time = time.monotonic()
    save_interval = 8
    for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
        if i == 0:
            time.sleep(0.125)  # We need this to make sure the first update goes through
    _tool_turn = 0
    while True:
        history = state['history']

        current_time = time.monotonic()
        # Save on first iteration or if save_interval seconds have passed
        if i == 0 or (current_time - last_save_time) >= save_interval:
        # Turn 0: use original flags; turns 2+: regenerate into the same entry.
        # _tool_turn tells chatbot_wrapper to skip version creation/sync so
        # that intermediate tool-loop regenerations don't pollute swipe history.
        if _tool_turn > 0:
            state['_tool_turn'] = True
            state['_skip_output_extensions'] = True

        regen = regenerate if _tool_turn == 0 else True
        cont = _continue if _tool_turn == 0 else False
        cur_text = text if _tool_turn == 0 else ''

        for i, history in enumerate(generate_chat_reply(cur_text, state, regen, cont, loading_message=True, for_ui=True)):
            # Prepend accumulated tool output to visible reply for display.
            # Save and restore the original to prevent the markers from leaking
            # back into chatbot_wrapper's shared output object, which would cause
            # duplication on the next yield.
            _original_visible = history['visible'][-1][1] if visible_prefix else None
            if visible_prefix:
                history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_original_visible])

            yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history

            if visible_prefix:
                history['visible'][-1][1] = _original_visible

            if i == 0:
                # Save old tool_sequence into version 0 (created by chatbot_wrapper
                # on the first yield). Only needed on the first regeneration when
                # versions didn't previously exist.
                if _old_tool_sequence is not None and _tool_turn == 0:
                    _ri = len(history['internal']) - 1
                    _versions = history.get('metadata', {}).get(f'assistant_{_ri}', {}).get('versions', [])
                    if _versions and 'tool_sequence' not in _versions[0]:
                        _versions[0]['tool_sequence'] = _old_tool_sequence
                    _old_tool_sequence = None

                time.sleep(0.125)

            current_time = time.monotonic()
            if i == 0 or (current_time - last_save_time) >= save_interval:
                save_history(history, state['unique_id'], state['character_menu'], state['mode'])
                last_save_time = current_time

            # Early stop on tool call detection
            if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names, parsers=_tool_parsers):
                break

        # Save the model's visible output before re-applying visible_prefix,
        # so we can extract thinking content from just this turn's output.
        _model_visible = history['visible'][-1][1]

        # Recover visible_prefix from existing visible text (e.g. on Continue
        # after a previous session had tool calls). Extract all <tool_call>
        # blocks and any text between them (thinking blocks, intermediate text).
        if tool_func_names and not visible_prefix and _model_visible:
|
||||
tc_matches = list(re.finditer(r'<tool_call>.*?</tool_call>', _model_visible, re.DOTALL))
|
||||
if tc_matches:
|
||||
prefix_end = tc_matches[-1].end()
|
||||
prefix = _model_visible[:prefix_end].strip()
|
||||
if prefix:
|
||||
visible_prefix = [prefix]
|
||||
_model_visible = _model_visible[prefix_end:].strip()
|
||||
|
||||
# Re-apply visible prefix to the final state after streaming completes.
|
||||
# This is safe because we're no longer sharing the object with chatbot_wrapper.
|
||||
if visible_prefix:
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_model_visible])
|
||||
|
||||
if tool_func_names:
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
last_save_time = current_time
|
||||
|
||||
# Check for tool calls
|
||||
if not tool_func_names or shared.stop_everything:
|
||||
break
|
||||
|
||||
answer = history['internal'][-1][1]
|
||||
parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True, parsers=_tool_parsers) if answer else (None, '')
|
||||
|
||||
if not parsed_calls:
|
||||
break # No tool calls — done
|
||||
|
||||
# --- Process tool calls ---
|
||||
row_idx = len(history['internal']) - 1
|
||||
meta = history.get('metadata', {})
|
||||
seq = meta.setdefault(f'assistant_{row_idx}', {}).setdefault('tool_sequence', [])
|
||||
|
||||
def _render():
|
||||
return chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
|
||||
|
||||
# Serialize tool calls and build display headers in one pass
|
||||
serialized = []
|
||||
tc_headers = []
|
||||
for tc in parsed_calls:
|
||||
tc['id'] = get_tool_call_id()
|
||||
fn_name = tc['function']['name']
|
||||
fn_args = tc['function'].get('arguments', {})
|
||||
|
||||
serialized.append({
|
||||
'id': tc['id'],
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': fn_name,
|
||||
'arguments': json.dumps(fn_args) if isinstance(fn_args, dict) else fn_args
|
||||
}
|
||||
})
|
||||
|
||||
if isinstance(fn_args, dict) and fn_args:
|
||||
args_summary = ', '.join(f'{k}={json.dumps(v, ensure_ascii=False)}' for k, v in fn_args.items())
|
||||
elif isinstance(fn_args, dict):
|
||||
args_summary = ''
|
||||
else:
|
||||
args_summary = str(fn_args)
|
||||
|
||||
tc_headers.append(f'{fn_name}({args_summary})')
|
||||
|
||||
seq_entry = {'tool_calls': serialized}
|
||||
if content_prefix.strip():
|
||||
# Strip GPT-OSS channel tokens so they don't get double-wrapped
|
||||
# by the template (which adds its own channel markup).
|
||||
clean = content_prefix.strip()
|
||||
if '<|channel|>' in clean and '<|message|>' in clean:
|
||||
inner = clean.split('<|message|>', 1)[1]
|
||||
if '<|end|>' in inner:
|
||||
inner = inner.split('<|end|>', 1)[0]
|
||||
clean = inner.strip()
|
||||
if clean:
|
||||
seq_entry['content'] = clean
|
||||
seq.append(seq_entry)
|
||||
|
||||
# Clear internal (raw tool markup)
|
||||
history['internal'][-1][1] = ''
|
||||
|
||||
# Preserve thinking block and intermediate text from this turn.
|
||||
# content_prefix is the raw text before tool call syntax (returned
|
||||
# by parse_tool_call); HTML-escape it and extract thinking to get
|
||||
# the content the user should see.
|
||||
content_text = html.escape(content_prefix)
|
||||
thinking_content, intermediate = extract_thinking_block(content_text)
|
||||
if thinking_content:
|
||||
visible_prefix.append(f'<think>\n{thinking_content}\n</think>')
|
||||
if intermediate and intermediate.strip():
|
||||
visible_prefix.append(intermediate.strip())
|
||||
|
||||
# Show placeholder accordions with "..." before execution starts
|
||||
# (tool calls may be slow, e.g. web search).
|
||||
pending_placeholders = [f'<tool_call>{h}\n...\n</tool_call>' for h in tc_headers]
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
|
||||
yield _render(), history
|
||||
|
||||
# Execute tools, store results, and replace placeholders with real results
|
||||
for i, tc in enumerate(parsed_calls):
|
||||
# Check for stop request before each tool execution
|
||||
if shared.stop_everything:
|
||||
for j in range(i, len(parsed_calls)):
|
||||
seq.append({'role': 'tool', 'content': 'Tool execution was cancelled by the user.', 'tool_call_id': parsed_calls[j]['id']})
|
||||
pending_placeholders[j] = f'<tool_call>{tc_headers[j]}\nCancelled\n</tool_call>'
|
||||
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
|
||||
yield _render(), history
|
||||
break
|
||||
|
||||
fn_name = tc['function']['name']
|
||||
fn_args = tc['function'].get('arguments', {})
|
||||
result = execute_tool(fn_name, fn_args, tool_executors)
|
||||
|
||||
seq.append({'role': 'tool', 'content': result, 'tool_call_id': tc['id']})
|
||||
try:
|
||||
pretty_result = json.dumps(json.loads(result), indent=2, ensure_ascii=False)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pretty_result = result
|
||||
|
||||
# Replace the placeholder with the real result
|
||||
pending_placeholders[i] = f'<tool_call>{tc_headers[i]}\n{pretty_result}\n</tool_call>'
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
|
||||
yield _render(), history
|
||||
|
||||
# Move completed tool calls into visible_prefix for next turns
|
||||
visible_prefix.extend(pending_placeholders)
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix)
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
|
||||
state['history'] = history
|
||||
_tool_turn += 1
|
||||
|
||||
state.pop('_tool_turn', None)
|
||||
|
||||
# If output extensions were deferred during tool turns, apply them now
|
||||
# to the final model response only (not to tool call markers).
|
||||
if state.pop('_skip_output_extensions', None):
|
||||
_model_visible = apply_extensions('output', _model_visible, state, is_chat=True)
|
||||
if visible_prefix:
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_model_visible])
|
||||
else:
|
||||
history['visible'][-1][1] = _model_visible
|
||||
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
|
||||
|
||||
state['history'] = history
|
||||
|
||||
# Sync version metadata so swipes show the full visible (with tool prefix)
|
||||
if visible_prefix and history.get('metadata'):
|
||||
row_idx = len(history['internal']) - 1
|
||||
key = f"assistant_{row_idx}"
|
||||
meta_entry = history['metadata'].get(key, {})
|
||||
if 'versions' in meta_entry and 'current_version_index' in meta_entry:
|
||||
current_idx = meta_entry['current_version_index']
|
||||
if current_idx < len(meta_entry['versions']):
|
||||
version_update = {
|
||||
'content': history['internal'][row_idx][1],
|
||||
'visible_content': history['visible'][row_idx][1]
|
||||
}
|
||||
ts = meta_entry.get('tool_sequence')
|
||||
if ts is not None:
|
||||
version_update['tool_sequence'] = ts
|
||||
meta_entry['versions'][current_idx].update(version_update)
|
||||
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
|
||||
|
|
@@ -1164,7 +1528,7 @@ def redraw_html(history, name1, name2, mode, style, character, reset_cache=False
    return chat_html_wrapper(history, name1, name2, mode, style, character, reset_cache=reset_cache)


def start_new_chat(state):
def start_new_chat(state, unique_id=None):
    mode = state['mode']
    # Initialize with empty metadata dictionary
    history = {'internal': [], 'visible': [], 'metadata': {}}

@@ -1178,7 +1542,9 @@ def start_new_chat(state):
    # Add timestamp for assistant's greeting
    update_message_metadata(history['metadata'], "assistant", 0, timestamp=get_current_timestamp())

    unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
    if unique_id is None:
        unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')

    save_history(history, unique_id, state['character_menu'], state['mode'])

    return history

@@ -1197,12 +1563,16 @@ def save_history(history, unique_id, character, mode):
    if shared.args.multi_user:
        return

    if unique_id and unique_id.startswith('incognito-'):
        return

    p = get_history_file_path(unique_id, character, mode)
    if not p.parent.is_dir():
        p.parent.mkdir(parents=True)

    with open(p, 'w', encoding='utf-8') as f:
        f.write(json.dumps(history, indent=4, ensure_ascii=False))
    with _history_file_lock:
        with open(p, 'w', encoding='utf-8') as f:
            f.write(json.dumps(history, indent=4, ensure_ascii=False))


def rename_history(old_id, new_id, character, mode):
@@ -1333,6 +1703,7 @@ def load_history_after_deletion(state, idx):
    Loads the latest history for the given character in chat or chat-instruct
    mode, or the latest instruct history for instruct mode.
    '''
    import gradio as gr

    if shared.args.multi_user:
        return start_new_chat(state)

@@ -1351,6 +1722,7 @@ def load_history_after_deletion(state, idx):


def update_character_menu_after_deletion(idx):
    import gradio as gr
    characters = utils.get_available_characters()
    idx = min(int(idx), len(characters) - 1)
    idx = max(0, idx)

@@ -1383,6 +1755,9 @@ def save_last_chat_state(character, mode, unique_id):
    if shared.args.multi_user:
        return

    if unique_id and unique_id.startswith('incognito-'):
        return

    state = load_last_chat_state()
    key = get_chat_state_key(character, mode)
    state["last_chats"][key] = unique_id
@@ -1565,24 +1940,6 @@ def clear_character_for_ui(state):
    return state, state['name2'], state['context'], state['greeting'], None


def load_instruction_template(template):
    if template == 'None':
        return ''

    for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']:
        if filepath.exists():
            break
    else:
        return ''

    file_contents = open(filepath, 'r', encoding='utf-8').read()
    data = yaml.safe_load(file_contents)
    if 'instruction_template' in data:
        return data['instruction_template']
    else:
        return jinja_template_from_old_format(data)


@functools.cache
def load_character_memoized(character, name1, name2):
    return load_character(character, name1, name2)

@@ -1590,10 +1947,12 @@ def load_character_memoized(character, name1, name2):

@functools.cache
def load_instruction_template_memoized(template):
    from modules.models_settings import load_instruction_template
    return load_instruction_template(template)


def upload_character(file, img_path, tavern=False):
    import gradio as gr
    img = open_image_safely(img_path)
    decoded_file = file if isinstance(file, str) else file.decode('utf-8')
    try:

@@ -1647,6 +2006,7 @@ def upload_tavern_character(img_path, _json):


def check_tavern_character(img_path):
    import gradio as gr
    img = open_image_safely(img_path)

    if img is None:

@@ -1832,6 +2192,7 @@ def delete_user(name):

def update_user_menu_after_deletion(idx):
    """Update user menu after a user is deleted"""
    import gradio as gr
    users = get_available_users()
    if len(users) == 0:
        # Create a default user if none exist

@@ -1864,93 +2225,13 @@ def handle_user_menu_change(state):

def handle_save_user_click(name1):
    """Handle save user button click"""
    import gradio as gr
    return [
        name1,
        gr.update(visible=True)
    ]


def jinja_template_from_old_format(params, verbose=False):
    MASTER_TEMPLATE = """
    {%- set ns = namespace(found=false) -%}
    {%- for message in messages -%}
        {%- if message['role'] == 'system' -%}
            {%- set ns.found = true -%}
        {%- endif -%}
    {%- endfor -%}
    {%- if not ns.found -%}
        {{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
    {%- endif %}
    {%- for message in messages %}
        {%- if message['role'] == 'system' -%}
            {{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
        {%- else -%}
            {%- if message['role'] == 'user' -%}
                {{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
            {%- else -%}
                {{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
            {%- endif -%}
        {%- endif -%}
    {%- endfor -%}
    {%- if add_generation_prompt -%}
        {{-'<|PRE-ASSISTANT-GENERATE|>'-}}
    {%- endif -%}
    """

    if 'context' in params and '<|system-message|>' in params['context']:
        pre_system = params['context'].split('<|system-message|>')[0]
        post_system = params['context'].split('<|system-message|>')[1]
    else:
        pre_system = ''
        post_system = ''

    pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
    post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]

    pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
    pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
    post_assistant = params['turn_template'].split('<|bot-message|>')[1]

    def preprocess(string):
        return string.replace('\n', '\\n').replace('\'', '\\\'')

    pre_system = preprocess(pre_system)
    post_system = preprocess(post_system)
    pre_user = preprocess(pre_user)
    post_user = preprocess(post_user)
    pre_assistant = preprocess(pre_assistant)
    post_assistant = preprocess(post_assistant)

    if verbose:
        print(
            '\n',
            repr(pre_system) + '\n',
            repr(post_system) + '\n',
            repr(pre_user) + '\n',
            repr(post_user) + '\n',
            repr(pre_assistant) + '\n',
            repr(post_assistant) + '\n',
        )

    result = MASTER_TEMPLATE
    if 'system_message' in params:
        result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message']))
    else:
        result = result.replace('<|SYSTEM-MESSAGE|>', '')

    result = result.replace('<|PRE-SYSTEM|>', pre_system)
    result = result.replace('<|POST-SYSTEM|>', post_system)
    result = result.replace('<|PRE-USER|>', pre_user)
    result = result.replace('<|POST-USER|>', post_user)
    result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
    result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
    result = result.replace('<|POST-ASSISTANT|>', post_assistant)

    result = result.strip()

    return result


def my_yaml_output(data):
    '''
    pyyaml is very inconsistent with multiline strings.
@@ -2002,6 +2283,7 @@ def handle_unique_id_select(state):


def handle_start_new_chat_click(state):
    import gradio as gr
    history = start_new_chat(state)
    histories = find_all_histories_with_first_prompts(state)
    html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])

@@ -2016,10 +2298,29 @@ def handle_start_new_chat_click(state):
    return [history, html, past_chats_update]


def handle_start_incognito_chat_click(state):
    import gradio as gr
    unique_id = 'incognito-' + datetime.now().strftime('%Y%m%d-%H-%M-%S')
    history = start_new_chat(state, unique_id=unique_id)
    html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])

    convert_to_markdown.cache_clear()

    histories = find_all_histories_with_first_prompts(state)
    past_chats_update = gr.update(choices=histories, value=unique_id)

    return [history, html, past_chats_update]


def handle_delete_chat_confirm_click(state):
    filtered_histories = find_all_histories_with_first_prompts(state)
    filtered_ids = [h[1] for h in filtered_histories]
    index = str(filtered_ids.index(state['unique_id']))

    if state['unique_id'] not in filtered_ids:
        # Incognito or unknown chat — just load the most recent saved chat
        index = '0'
    else:
        index = str(filtered_ids.index(state['unique_id']))

    delete_history(state['unique_id'], state['character_menu'], state['mode'])
    history, unique_id = load_history_after_deletion(state, index)
@@ -2086,14 +2383,19 @@ def handle_edit_message_click(state):
    original_visible = history['visible'][message_index][role_idx]
    original_timestamp = history['metadata'][key].get('timestamp', get_current_timestamp())

    history['metadata'][key]["versions"] = [{
    version_entry = {
        "content": original_content,
        "visible_content": original_visible,
        "timestamp": original_timestamp
    }]
    }
    ts = history['metadata'][key].get('tool_sequence')
    if ts is not None:
        version_entry['tool_sequence'] = ts
    history['metadata'][key]["versions"] = [version_entry]

    history['internal'][message_index][role_idx] = apply_extensions('input', new_text, state, is_chat=True)
    history['visible'][message_index][role_idx] = html.escape(new_text)
    history['metadata'][key].pop('tool_sequence', None)

    add_message_version(history, role, message_index, is_current=True)

@@ -2138,6 +2440,14 @@ def handle_navigate_version_click(state):
    history['internal'][message_index][msg_content_idx] = version_to_load['content']
    history['visible'][message_index][msg_content_idx] = version_to_load['visible_content']
    metadata['current_version_index'] = new_idx

    # Restore per-version tool_sequence so follow-up prompts see consistent context
    version_ts = version_to_load.get('tool_sequence')
    if version_ts is not None:
        metadata['tool_sequence'] = version_ts
    else:
        metadata.pop('tool_sequence', None)

    update_message_metadata(history['metadata'], role, message_index, timestamp=version_to_load['timestamp'])

    # Redraw and save

@@ -2148,6 +2458,7 @@ def handle_navigate_version_click(state):


def handle_rename_chat_click():
    import gradio as gr
    return [
        gr.update(value="My New Chat"),
        gr.update(visible=True),

@@ -2155,6 +2466,14 @@ def handle_rename_chat_click():


def handle_rename_chat_confirm(rename_to, state):
    import gradio as gr

    if state['unique_id'] and state['unique_id'].startswith('incognito-'):
        return [
            gr.update(),
            gr.update(visible=False),
        ]

    rename_history(state['unique_id'], rename_to, state['character_menu'], state['mode'])
    histories = find_all_histories_with_first_prompts(state)

@@ -2165,11 +2484,13 @@ def handle_rename_chat_confirm(rename_to, state):


def handle_search_chat_change(state):
    import gradio as gr
    histories = find_all_histories_with_first_prompts(state)
    return gr.update(choices=histories)


def handle_upload_chat_history(load_chat_history, state):
    import gradio as gr
    history = start_new_chat(state)
    history = load_history_json(load_chat_history, history)
    save_history(history, state['unique_id'], state['character_menu'], state['mode'])

@@ -2192,6 +2513,7 @@ def handle_upload_chat_history(load_chat_history, state):


def handle_character_menu_change(state):
    import gradio as gr
    name1, name2, picture, greeting, context = load_character(state['character_menu'], state['name1'], state['name2'])

    state['name1'] = name1

@@ -2244,6 +2566,7 @@ def handle_character_picture_change(picture_path):


def handle_mode_change(state):
    import gradio as gr
    history, loaded_unique_id = load_latest_history(state)
    histories = find_all_histories_with_first_prompts(state)


@@ -2270,6 +2593,7 @@ def handle_mode_change(state):


def handle_save_character_click(name2):
    import gradio as gr
    return [
        name2,
        gr.update(visible=True)

@@ -2277,6 +2601,7 @@ def handle_save_character_click(name2):


def handle_load_template_click(instruction_template):
    from modules.models_settings import load_instruction_template
    output = load_instruction_template(instruction_template)
    return [
        output,

@@ -2285,6 +2610,7 @@ def handle_load_template_click(instruction_template):


def handle_save_template_click(instruction_template_str):
    import gradio as gr
    contents = generate_instruction_template_yaml(instruction_template_str)
    return [
        "My Template.yaml",

@@ -2295,6 +2621,7 @@ def handle_save_template_click(instruction_template_str):


def handle_delete_template_click(template):
    import gradio as gr
    return [
        f"{template}.yaml",
        str(shared.user_data_dir / 'instruction-templates') + '/',

@@ -2310,6 +2637,7 @@ def handle_your_picture_change(picture, state):


def handle_send_instruction_click(state):
    import gradio as gr
    state['mode'] = 'instruct'
    state['history'] = {'internal': [], 'visible': [], 'metadata': {}}


@@ -2322,6 +2650,7 @@ def handle_send_instruction_click(state):


def handle_send_chat_click(state):
    import gradio as gr
    output = generate_chat_prompt("", state, _continue=True)

    if state["show_two_notebook_columns"]:
@@ -1,3 +1,4 @@
import math
import queue
import threading
import traceback

@@ -9,6 +10,7 @@ import torch
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job
from exllamav3.generator.filter import Filter
from exllamav3.generator.sampler import (
    CustomSampler,
    SS_AdaptiveP,

@@ -36,6 +38,29 @@ except Exception:
    traceback.print_exc()


class LogitBiasFilter(Filter):
    """Filter subclass that applies a static additive logit bias mask."""

    def __init__(self, tokenizer, logit_bias_dict):
        super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False)
        self.logit_bias_dict = logit_bias_dict
        self._mask = None

    def reset(self): pass
    def accept_token(self, token): pass
    def is_completed(self): return False
    def use_background_worker(self): return False

    def get_next_logit_mask(self):
        if self._mask is None:
            self._mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
            for token_id_str, bias in self.logit_bias_dict.items():
                token_id = int(token_id_str)
                if 0 <= token_id < self.vocab_size:
                    self._mask[0, token_id] = bias
        return self._mask


class ConcurrentGenerator:
    def __init__(self, generator):
        self.generator = generator

@@ -53,7 +78,16 @@ class ConcurrentGenerator:
            if not self.job_queues:
                self.has_jobs.clear()
                continue
            results = self.generator.iterate()
            try:
                results = self.generator.iterate()
            except Exception:
                logger.error("Exception in ConcurrentGenerator iterate loop:\n" + traceback.format_exc())
                for q in self.job_queues.values():
                    q.put(None)
                self.job_queues.clear()
                self.generator.clear_queue()
                self.has_jobs.clear()
                continue
            for result in results:
                job = result["job"]
                q = self.job_queues.get(job)

@@ -89,6 +123,10 @@ class Exllamav3Model:
    def __init__(self):
        pass

    @property
    def device(self) -> torch.device:
        return torch.device(0)

    @classmethod
    def from_pretrained(cls, path_to_model):
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)

@@ -149,8 +187,21 @@ class Exllamav3Model:
            load_params['tensor_p'] = True
            load_params['tp_backend'] = shared.args.tp_backend

        model.load(**load_params)
        tokenizer = Tokenizer.from_config(config)
        # Load vision and draft before the main model so autosplit
        # accounts for their VRAM usage.

        # Load vision model component (ExLlamaV3 native)
        vision_model = None
        if "vision_config" in config.config_dict:
            logger.info("Vision component detected in model config. Attempting to load...")
            try:
                vision_model = Model.from_config(config, component="vision")
                vision_model.load(progressbar=True)
                logger.info("Vision model loaded successfully.")
            except Exception as e:
                logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
        else:
            logger.info("No vision component in model config. Skipping multimodal setup.")

        # Initialize draft model for speculative decoding
        draft_model = None

@@ -166,23 +217,8 @@ class Exllamav3Model:
                logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
            else:
                draft_config = Config.from_directory(str(draft_path))

                # Set context size for draft model with 256-multiple validation
                if shared.args.ctx_size_draft > 0:
                    draft_max_tokens = shared.args.ctx_size_draft
                else:
                    draft_max_tokens = shared.args.ctx_size

                # Validate draft model context size is a multiple of 256
                if draft_max_tokens % 256 != 0:
                    adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
                    logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
                    draft_max_tokens = adjusted_draft_tokens

                draft_config.max_seq_len = draft_max_tokens

                draft_model = Model.from_config(draft_config)
                draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
                draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)

                draft_load_params = {'progressbar': True}
                if split:

@@ -191,18 +227,9 @@ class Exllamav3Model:
                draft_model.load(**draft_load_params)
                logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")

        # Load vision model component (ExLlamaV3 native)
        vision_model = None
        if "vision_config" in config.config_dict:
            logger.info("Vision component detected in model config. Attempting to load...")
            try:
                vision_model = Model.from_config(config, component="vision")
                vision_model.load(progressbar=True)
                logger.info("Vision model loaded successfully.")
            except Exception as e:
                logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
        else:
            logger.info("No vision component in model config. Skipping multimodal setup.")
        # Load main model last
        model.load(**load_params)
        tokenizer = Tokenizer.from_config(config)

        generator = Generator(
            model=model,

@@ -385,11 +412,22 @@ class Exllamav3Model:
        else:
            max_new_tokens = state['max_new_tokens']

        # Get stop conditions
        # Use full EOS token list from config (may contain multiple IDs)
        stop_conditions = []
        if not state['ban_eos_token']:
            if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
                stop_conditions.append(self.tokenizer.eos_token_id)
            for eos_id in self.config.eos_token_id_list:
                if eos_id is not None:
                    stop_conditions.append(eos_id)

        # Build filters for logit_bias (OpenAI API)
        filters = []
        logit_bias = state.get('logit_bias')
        if logit_bias:
            filters.append(LogitBiasFilter(self.tokenizer, logit_bias))

        # Logprobs support (OpenAI API)
        logprobs = state.get('logprobs', 0) or 0
        return_top_tokens = logprobs if logprobs > 0 else 0

        seed = state.get('seed', -1)
        job = Job(

@@ -400,11 +438,15 @@ class Exllamav3Model:
            sampler=sampler,
            seed=seed if seed >= 0 else None,
            stop_conditions=stop_conditions if stop_conditions else None,
            filters=filters if filters else None,
            return_top_tokens=return_top_tokens,
            return_probs=return_top_tokens > 0,
        )

        # Stream generation
        response_text = ""
        stop_event = state.get('stop_event')
        self.last_completion_probabilities = []

        result_queue = self.parallel_generator.submit(job)
        try:

@@ -416,14 +458,41 @@ class Exllamav3Model:
                except queue.Empty:
                    continue
                if result is None or result.get("eos"):
                    # Capture logprobs from the final eos result too
                    if result is not None and return_top_tokens > 0:
                        self._capture_logprobs(result)
                    break
                chunk = result.get("text", "")

                # Capture logprobs from streaming results
                if return_top_tokens > 0:
                    self._capture_logprobs(result)

                if chunk:
                    response_text += chunk
                    yield response_text
        finally:
            self.parallel_generator.cancel(job)

    def _capture_logprobs(self, result):
        """Convert ExLlamav3 top-k token data to the shared logprobs format."""
        top_k_tokens = result.get("top_k_tokens")
        top_k_probs = result.get("top_k_probs")
        if top_k_tokens is None or top_k_probs is None:
            return

        id_to_piece = self.tokenizer.get_id_to_piece_list(True)
        # top_k_tokens shape: (batch, seq_len, k), top_k_probs same
|
||||
for seq_idx in range(top_k_tokens.shape[1]):
|
||||
entry = {"top_logprobs": []}
|
||||
for k_idx in range(top_k_tokens.shape[2]):
|
||||
token_id = top_k_tokens[0, seq_idx, k_idx].item()
|
||||
prob = top_k_probs[0, seq_idx, k_idx].item()
|
||||
token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
|
||||
logprob = math.log(prob) if prob > 0 else float("-inf")
|
||||
entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
|
||||
self.last_completion_probabilities.append(entry)
|
||||
|
||||
def generate(self, prompt, state):
|
||||
output = ""
|
||||
for chunk in self.generate_with_streaming(prompt, state):
|
||||
|
|
|
|||
|
|
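The log transform at the heart of `_capture_logprobs` above can be isolated: each top-k probability becomes a logprob, with a guard so that zero probability maps to negative infinity instead of raising. A minimal standalone sketch (the helper name is hypothetical, not from the diff):

```python
import math

def probs_to_logprobs_entry(top_tokens, top_probs):
    """Turn parallel lists of token strings and probabilities into one
    OpenAI-style entry: log of each probability, -inf when it is zero."""
    entry = {"top_logprobs": []}
    for token, prob in zip(top_tokens, top_probs):
        logprob = math.log(prob) if prob > 0 else float("-inf")
        entry["top_logprobs"].append({"token": token, "logprob": logprob})
    return entry

entry = probs_to_logprobs_entry(["a", "b"], [0.5, 0.0])
```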
@@ -201,19 +201,23 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
    }
).to(input_ids.device).float()
else:
    # When processing with labels, handle as a complete sequence
    # Process in chunks if the number of tokens is large
    # Labels path: use cache for cross-chunk attention.
    tokens_to_process = seq_tensor
    all_logits = None
    current_len = 0

    for i in range(0, tokens_to_process.shape[0], max_chunk_size):
        chunk = tokens_to_process[i:i + max_chunk_size]
        chunk_logits = self.ex_model.forward(
            input_ids=chunk.view(1, -1),
            params={
                "attn_mode": "flash_attn_nc",
                "attn_mode": "flash_attn",
                "cache": ex_cache,
                "past_len": current_len,
                "batch_shape": (1, self.max_tokens),
            }
        ).float()
        current_len += chunk.shape[0]

        if all_logits is None:
            all_logits = chunk_logits
@@ -6,8 +6,6 @@ from functools import partial
from inspect import signature
from pathlib import Path

import gradio as gr

import modules.shared as shared
from modules.logging_colors import logger

@@ -214,6 +212,7 @@ def _apply_custom_js():


def create_extensions_block():
    import gradio as gr
    to_display = []
    for extension, name in iterator():
        if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):

@@ -228,6 +227,7 @@ def create_extensions_block():


def create_extensions_tabs():
    import gradio as gr
    for extension, name in iterator():
        if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
            display_name = getattr(extension, 'params', {}).get('display_name', name)
@@ -10,6 +10,7 @@ import markdown
from PIL import Image, ImageOps

from modules import shared
from modules.reasoning import extract_reasoning
from modules.sane_markdown_lists import SaneListExtension
from modules.utils import get_available_chat_styles

@@ -108,69 +109,41 @@ def replace_blockquote(m):
    return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '')


# Thinking block format definitions: (start_tag, end_tag, content_start_tag)
# Use None for start_tag to match from beginning (end-only formats should be listed last)
THINKING_FORMATS = [
    ('<think>', '</think>', None),
    ('<|channel|>analysis<|message|>', '<|end|>', '<|start|>assistant<|channel|>final<|message|>'),
    ('<seed:think>', '</seed:think>', None),
    ('<|think|>', '<|end|>', '<|content|>'),  # Solar Open
    ('Thinking Process:', '</think>', None),  # Qwen3.5 verbose thinking outside tags
    (None, '</think>', None),  # End-only variant (e.g., Qwen3-next)
]


def extract_thinking_block(string):
    """Extract thinking blocks from the beginning of a string."""
    if not string:
        return None, string

    for start_tag, end_tag, content_tag in THINKING_FORMATS:
        end_esc = html.escape(end_tag)
        content_esc = html.escape(content_tag) if content_tag else None

        if start_tag is None:
            # End-only format: require end tag, start from beginning
            end_pos = string.find(end_esc)
            if end_pos == -1:
                continue
            thought_start = 0
        else:
            # Normal format: require start tag
            start_esc = html.escape(start_tag)
            start_pos = string.find(start_esc)
            if start_pos == -1:
                continue
            thought_start = start_pos + len(start_esc)
            end_pos = string.find(end_esc, thought_start)

        if end_pos == -1:
            # End tag missing - check if content tag can serve as fallback
            if content_esc:
                content_pos = string.find(content_esc, thought_start)
                if content_pos != -1:
                    thought_end = content_pos
                    content_start = content_pos + len(content_esc)
                else:
                    thought_end = len(string)
                    content_start = len(string)
            else:
                thought_end = len(string)
                content_start = len(string)
        else:
            thought_end = end_pos
            if content_esc:
                content_pos = string.find(content_esc, end_pos)
                content_start = content_pos + len(content_esc) if content_pos != -1 else end_pos + len(end_esc)
            else:
                content_start = end_pos + len(end_esc)

        return string[thought_start:thought_end], string[content_start:]

    return None, string
    """Extract thinking blocks from the beginning of an HTML-escaped string."""
    return extract_reasoning(string, html_escaped=True)


def build_thinking_block(thinking_content, message_id, has_remaining_content):

def build_tool_call_block(header, body, message_id, index):
    """Build HTML for a tool call accordion block."""
    block_id = f"tool-call-{message_id}-{index}"

    if body == '...':
        # Pending placeholder — no expandable body, just title with ellipsis
        return f'''
        <details class="thinking-block" data-block-id="{block_id}">
            <summary class="thinking-header">
                {tool_svg_small}
                <span class="thinking-title">{html.escape(header)} ...</span>
            </summary>
        </details>
        '''

    # Build a plain <pre> directly to avoid highlight.js auto-detection
    escaped_body = html.escape(body)
    return f'''
    <details class="thinking-block" data-block-id="{block_id}">
        <summary class="thinking-header">
            {tool_svg_small}
            <span class="thinking-title">{html.escape(header)}</span>
        </summary>
        <div class="thinking-content pretty_scrollbar"><pre><code class="nohighlight">{escaped_body}</code></pre></div>
    </details>
    '''


def build_thinking_block(thinking_content, message_id, has_remaining_content, thinking_index=0):
    """Build HTML for a thinking block."""
    if thinking_content is None:
        return None

@@ -179,7 +152,7 @@ def build_thinking_block(thinking_content, message_id, has_remaining_content):
    thinking_html = process_markdown_content(thinking_content)

    # Generate unique ID for the thinking block
    block_id = f"thinking-{message_id}-0"
    block_id = f"thinking-{message_id}-{thinking_index}"

    # Check if thinking is complete or still in progress
    is_streaming = not has_remaining_content

@@ -344,6 +317,9 @@ def process_markdown_content(string):
    # Unescape backslashes
    html_output = html_output.replace('\\\\', '\\')

    # Wrap tables in a scrollable div
    html_output = html_output.replace('<table>', '<div class="table-wrapper pretty_scrollbar"><table>').replace('</table>', '</table></div>')

    return html_output


@@ -360,24 +336,66 @@ def convert_to_markdown(string, message_id=None):
    if message_id is None:
        message_id = "unknown"

    # Extract different components from the string
    thinking_content, remaining_content = extract_thinking_block(string)
    # Find tool call blocks by position, then process the text segments
    # between them using extract_thinking_block (which supports all
    # THINKING_FORMATS, including end-only variants like Qwen's).
    tool_call_pattern = re.compile(r'<tool_call>(.*?)\n(.*?)\n</tool_call>', re.DOTALL)
    tool_calls = list(tool_call_pattern.finditer(string))

    # Build individual HTML blocks
    blocks = []
    if not tool_calls:
        # No tool calls — use original single-pass extraction
        thinking_content, remaining_content = extract_thinking_block(string)
        blocks = []
        thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content))
        if thinking_html:
            blocks.append(thinking_html)

    # Add thinking block if present
    thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content))
    if thinking_html:
        blocks.append(thinking_html)
        main_html = build_main_content_block(remaining_content)
        if main_html:
            blocks.append(main_html)

    # Add main content block
    main_html = build_main_content_block(remaining_content)
    if main_html:
        blocks.append(main_html)
        return ''.join(blocks)

    # Assemble all blocks into final HTML
    return ''.join(blocks)
    # Split string into text segments around tool_call blocks and
    # run extract_thinking_block on each segment for full format support.
    html_parts = []
    last_end = 0
    tool_idx = 0
    think_idx = 0

    def process_text_segment(text, is_last_segment):
        """Process a text segment between tool_call blocks for thinking content."""
        nonlocal think_idx
        if not text.strip():
            return

        while text.strip():
            thinking_content, remaining = extract_thinking_block(text)
            if thinking_content is None:
                break
            has_remaining = bool(remaining.strip()) or not is_last_segment
            html_parts.append(build_thinking_block(thinking_content, message_id, has_remaining, think_idx))
            think_idx += 1
            text = remaining

        if text.strip():
            html_parts.append(process_markdown_content(text))

    for tc in tool_calls:
        # Process text before this tool_call
        process_text_segment(string[last_end:tc.start()], is_last_segment=False)

        # Add tool call accordion
        header = tc.group(1).strip()
        body = tc.group(2).strip()
        html_parts.append(build_tool_call_block(header, body, message_id, tool_idx))
        tool_idx += 1
        last_end = tc.end()

    # Process text after the last tool_call
    process_text_segment(string[last_end:], is_last_segment=True)

    return ''.join(html_parts)


def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):

@@ -435,6 +453,7 @@ branch_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24
edit_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="tabler-icon tabler-icon-pencil"><path d="M4 20h4l10.5 -10.5a2.828 2.828 0 1 0 -4 -4l-10.5 10.5v4"></path><path d="M13.5 6.5l4 4"></path></svg>'''
info_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
info_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
tool_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-tool"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M7 10h3v-3l-3.5 -3.5a6 6 0 0 1 8 8l6 6a2 2 0 0 1 -3 3l-6 -6a6 6 0 0 1 -8 -8l3.5 3.5" /></svg>'''
attachment_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.48-8.48l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48"></path></svg>'''

copy_button = f'<button class="footer-button footer-copy-button" title="Copy" onclick="copyToClipboard(this)">{copy_svg}</button>'
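The tool-call regex added to convert_to_markdown above carves a message into plain-text segments and accordion blocks: group 1 is the header line, group 2 the (possibly multi-line) body. A small sketch of that pattern in action (the example message is invented):

```python
import re

# Same pattern as in convert_to_markdown: header on the first line,
# body (DOTALL, so it may span lines) until the closing tag.
tool_call_pattern = re.compile(r'<tool_call>(.*?)\n(.*?)\n</tool_call>', re.DOTALL)

message = 'thinking done<tool_call>get_weather\n{"city": "Paris"}\n</tool_call>final answer'
calls = list(tool_call_pattern.finditer(message))

header = calls[0].group(1).strip()   # the tool name line
body = calls[0].group(2).strip()     # the JSON arguments
before = message[:calls[0].start()]  # text segment before the call
after = message[calls[0].end():]     # text segment after the call
```

The surrounding `before`/`after` segments are what `process_text_segment` then scans for thinking blocks.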
@@ -36,6 +36,7 @@ class LlamaServer:
        self.process = None
        self.session = requests.Session()
        self.vocabulary_size = None
        self.n_ctx = None
        self.bos_token = "<s>"
        self.last_prompt_token_count = 0

@@ -133,9 +134,20 @@ class LlamaServer:

        payload["samplers"] = filtered_samplers

        logit_bias = []
        if state['custom_token_bans']:
            to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
            payload["logit_bias"] = to_ban
            logit_bias.extend([[int(token_id.strip()), False] for token_id in state['custom_token_bans'].split(',') if token_id.strip()])

        if state.get('logit_bias'):
            for token_id_str, bias in state['logit_bias'].items():
                logit_bias.append([int(token_id_str), bias])

        if logit_bias:
            payload["logit_bias"] = logit_bias

        n_probs = state.get('logprobs', 0)
        if n_probs and n_probs > 0:
            payload["n_probs"] = n_probs

        return payload

@@ -215,6 +227,7 @@ class LlamaServer:
        response.raise_for_status()  # Raise an exception for HTTP errors

        full_text = ""
        self.last_completion_probabilities = []

        # Process the streaming response
        stop_event = state.get('stop_event')

@@ -240,6 +253,10 @@ class LlamaServer:
                full_text += data['content']
                yield full_text

                # Capture logprobs if present
                if 'completion_probabilities' in data:
                    self.last_completion_probabilities.extend(data['completion_probabilities'])

                # Check if generation is complete
                if data.get('stop', False):
                    break

@@ -304,12 +321,17 @@ class LlamaServer:
        self.vocabulary_size = model_info["meta"]["n_vocab"]

    def _get_bos_token(self):
        """Get and store the model's BOS token."""
        """Get and store the model's BOS token and context size."""
        url = f"http://127.0.0.1:{self.port}/props"
        response = self.session.get(url).json()
        if "bos_token" in response:
            self.bos_token = response["bos_token"]

        # Get actual n_ctx from the server (important when --fit auto-selects it)
        n_ctx = response.get("default_generation_settings", {}).get("n_ctx")
        if n_ctx:
            self.n_ctx = n_ctx

    def _is_port_available(self, port):
        """Check if a port is available for use."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:

@@ -349,11 +371,14 @@ class LlamaServer:

        if shared.args.ctx_size > 0:
            cmd += ["--ctx-size", str(shared.args.ctx_size)]
        elif shared.args.gpu_layers >= 0:
            cmd += ["--ctx-size", "8192"]

        if shared.args.gpu_layers >= 0:
            cmd += ["--gpu-layers", str(shared.args.gpu_layers), "--fit", "off"]
        else:
            cmd += ["--fit", "on"]
            cmd += ["--fit-ctx", "8192"]
        if shared.args.fit_target:
            cmd += ["--fit-target", shared.args.fit_target]

@@ -379,10 +404,6 @@ class LlamaServer:
        if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
            cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
            cache_type = shared.args.cache_type
        if shared.args.compress_pos_emb != 1:
            cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
        if shared.args.rope_freq_base > 0:
            cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
        if shared.args.mmproj not in [None, 'None']:
            path = Path(shared.args.mmproj)
            if not path.exists():

@@ -455,7 +476,7 @@ class LlamaServer:
        print()

        gpu_layers_str = "auto" if shared.args.gpu_layers < 0 else str(shared.args.gpu_layers)
        ctx_size_str = "auto" if shared.args.ctx_size == 0 else str(shared.args.ctx_size)
        ctx_size_str = "auto" if shared.args.ctx_size == 0 and shared.args.gpu_layers < 0 else str(shared.args.ctx_size or 8192)
        logger.info(f"Using gpu_layers={gpu_layers_str} | ctx_size={ctx_size_str} | cache_type={cache_type}")
        # Start the server with pipes for output
        self.process = subprocess.Popen(
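The logit_bias handling added to build_payload above merges two sources into one list: comma-separated ban IDs become `[id, false]` entries, and explicit per-token biases become `[id, bias]` entries. A hedged sketch of just that merge (the helper name is hypothetical):

```python
def build_logit_bias(custom_token_bans, logit_bias_map):
    """Merge banned tokens and explicit biases the same way the
    build_payload hunk does: bans first, then numeric biases."""
    merged = []
    if custom_token_bans:
        merged.extend([[int(t.strip()), False]
                       for t in custom_token_bans.split(',') if t.strip()])
    for token_id_str, bias in (logit_bias_map or {}).items():
        merged.append([int(token_id_str), bias])
    return merged

payload_bias = build_logit_bias("13, 29871", {"42": -5.0})
```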
@@ -1,8 +1,6 @@
import functools
from collections import OrderedDict

import gradio as gr

loaders_and_params = OrderedDict({
    'llama.cpp': [
        'gpu_layers',

@@ -17,8 +15,6 @@ loaders_and_params = OrderedDict({
        'tensor_split',
        'extra_flags',
        'streaming_llm',
        'rope_freq_base',
        'compress_pos_emb',
        'row_split',
        'no_kv_offload',
        'no_mmap',

@@ -43,8 +39,6 @@ loaders_and_params = OrderedDict({
    'Transformers': [
        'gpu_split',
        'cpu_memory',
        'alpha_value',
        'compress_pos_emb',
        'compute_dtype',
        'quant_type',
        'load_in_8bit',

@@ -71,7 +65,6 @@ loaders_and_params = OrderedDict({
        'gpu_split',
        'model_draft',
        'draft_max',
        'ctx_size_draft',
        'speculative_decoding_accordion',
        'enable_tp',
        'tp_backend',

@@ -208,6 +201,7 @@ loaders_samplers = {
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'reasoning_effort',
        'seed',
        'skip_special_tokens',
    },

@@ -244,6 +238,7 @@ loaders_samplers = {
        'reasoning_effort',
        'seed',
        'sampler_priority',
        'custom_token_bans',
        'dry_sequence_breakers',
        'grammar_string',
        'grammar_file_row',

@@ -277,6 +272,7 @@ def list_all_samplers():


def blacklist_samplers(loader, dynamic_temperature):
    import gradio as gr
    all_samplers = list_all_samplers()
    output = []

@@ -302,7 +298,58 @@ def get_all_params():
    return sorted(all_params)


def list_model_elements():
    return [
        'filter_by_loader',
        'loader',
        'cpu_memory',
        'gpu_layers',
        'fit_target',
        'cpu_moe',
        'threads',
        'threads_batch',
        'batch_size',
        'ubatch_size',
        'ctx_size',
        'cache_type',
        'tensor_split',
        'extra_flags',
        'streaming_llm',
        'gpu_split',
        'compute_dtype',
        'quant_type',
        'load_in_8bit',
        'load_in_4bit',
        'attn_implementation',
        'cpu',
        'disk',
        'row_split',
        'no_kv_offload',
        'no_mmap',
        'mlock',
        'numa',
        'parallel',
        'use_double_quant',
        'bf16',
        'enable_tp',
        'tp_backend',
        'cfg_cache',
        'no_use_fast',
        'model_draft',
        'draft_max',
        'gpu_layers_draft',
        'device_draft',
        'ctx_size_draft',
        'spec_type',
        'spec_ngram_size_n',
        'spec_ngram_size_m',
        'spec_ngram_min_hits',
        'mmproj',
    ]


def make_loader_params_visible(loader):
    import gradio as gr
    params = []
    all_params = get_all_params()
    if loader in loaders_and_params:
@@ -38,6 +38,9 @@ def load_model(model_name, loader=None):
    sampler_hijack.hijack_samplers()

    shared.args.loader = loader
    if loader != 'llama.cpp' and shared.args.ctx_size == 0:
        shared.args.ctx_size = 8192

    output = load_func_map[loader](model_name)
    if type(output) is tuple:
        model, tokenizer = output

@@ -54,6 +57,8 @@ def load_model(model_name, loader=None):
    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
        if shared.args.ctx_size > 0:
            shared.settings['truncation_length'] = shared.args.ctx_size
        elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx:
            shared.settings['truncation_length'] = model.n_ctx

    shared.is_multimodal = False
    if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):
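The truncation-length fallback in load_model above can be summarized as a small selection function: an explicit `--ctx-size` always wins, and otherwise llama.cpp reports the auto-selected `n_ctx` back via `/props`. A sketch with a hypothetical name, not the actual code:

```python
def pick_truncation_length(loader, ctx_size_arg, model_n_ctx, current):
    """Mirror of the updated logic: explicit ctx_size first, then the
    n_ctx the llama.cpp server auto-selected, else leave unchanged."""
    if loader.lower().startswith(('exllama', 'tensorrt')) or loader == 'llama.cpp':
        if ctx_size_arg > 0:
            return ctx_size_arg
        if loader == 'llama.cpp' and model_n_ctx:
            return model_n_ctx
    return current
```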
@@ -4,10 +4,9 @@ import re
from math import floor
from pathlib import Path

import gradio as gr
import yaml

from modules import chat, loaders, metadata_gguf, shared, ui
from modules import loaders, metadata_gguf, shared
from modules.logging_colors import logger
from modules.utils import resolve_model_path

@@ -16,9 +15,6 @@ def get_fallback_settings():
    return {
        'bf16': False,
        'ctx_size': 8192,
        'rope_freq_base': 0,
        'compress_pos_emb': 1,
        'alpha_value': 1,
        'truncation_length': shared.settings['truncation_length'],
        'truncation_length_info': shared.settings['truncation_length'],
        'skip_special_tokens': shared.settings['skip_special_tokens'],

@@ -68,14 +64,8 @@ def get_model_metadata(model):

    for k in metadata:
        if k.endswith('.context_length'):
            model_settings['ctx_size'] = min(metadata[k], 8192)
            model_settings['ctx_size'] = 0
            model_settings['truncation_length_info'] = metadata[k]
        elif k.endswith('rope.freq_base'):
            model_settings['rope_freq_base'] = metadata[k]
        elif k.endswith('rope.scale_linear'):
            model_settings['compress_pos_emb'] = metadata[k]
        elif k.endswith('rope.scaling.factor'):
            model_settings['compress_pos_emb'] = metadata[k]
        elif k.endswith('.block_count'):
            model_settings['gpu_layers'] = -1
            model_settings['max_gpu_layers'] = metadata[k] + 1

@@ -120,15 +110,6 @@ def get_model_metadata(model):
                model_settings['ctx_size'] = min(value, 8192)
                break

    if 'rope_theta' in metadata:
        model_settings['rope_freq_base'] = metadata['rope_theta']
    elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']:
        model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta']

    if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
        if metadata['rope_scaling']['type'] == 'linear':
            model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']

    if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
        model_settings['bf16'] = True

@@ -182,10 +163,6 @@ def get_model_metadata(model):
    if 'instruction_template' not in model_settings:
        model_settings['instruction_template'] = 'Alpaca'

    # Ignore rope_freq_base if set to the default value
    if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000:
        model_settings.pop('rope_freq_base')

    # Apply user settings from user_data/models/config-user.yaml
    settings = shared.user_config
    for pat in settings:

@@ -199,7 +176,7 @@ def get_model_metadata(model):

    # Load instruction template if defined by name rather than by value
    if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
        model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
        model_settings['instruction_template_str'] = load_instruction_template(model_settings['instruction_template'])

    return model_settings

@@ -228,7 +205,7 @@ def update_model_parameters(state, initial=False):
    '''
    UI: update the command-line arguments based on the interface values
    '''
    elements = ui.list_model_elements()  # the names of the parameters
    elements = loaders.list_model_elements()  # the names of the parameters

    for i, element in enumerate(elements):
        if element not in state:

@@ -248,6 +225,7 @@ def apply_model_settings_to_state(model, state):
    '''
    UI: update the state variable with the model settings
    '''
    import gradio as gr
    model_settings = get_model_metadata(model)
    if 'loader' in model_settings:
        loader = model_settings.pop('loader')

@@ -290,7 +268,7 @@ def save_model_settings(model, state):
        if model_regex not in user_config:
            user_config[model_regex] = {}

        for k in ui.list_model_elements():
        for k in loaders.list_model_elements():
            if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
                user_config[model_regex][k] = state[k]

@@ -419,3 +397,103 @@ def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type):

    vram_usage = estimate_vram(model, gpu_layers, ctx_size, cache_type)
    return f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">{vram_usage:.0f} MiB</span></div>"


def load_instruction_template(template):
    if template == 'None':
        return ''

    for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']:
        if filepath.exists():
            break
    else:
        return ''

    with open(filepath, 'r', encoding='utf-8') as f:
        file_contents = f.read()
    data = yaml.safe_load(file_contents)
    if 'instruction_template' in data:
        return data['instruction_template']
    else:
        return _jinja_template_from_old_format(data)


def _jinja_template_from_old_format(params, verbose=False):
    MASTER_TEMPLATE = """
{%- set ns = namespace(found=false) -%}
{%- for message in messages -%}
    {%- if message['role'] == 'system' -%}
        {%- set ns.found = true -%}
    {%- endif -%}
{%- endfor -%}
{%- if not ns.found -%}
    {{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
{%- endif %}
{%- for message in messages %}
    {%- if message['role'] == 'system' -%}
        {{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
    {%- else -%}
        {%- if message['role'] == 'user' -%}
            {{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
        {%- else -%}
            {{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
        {%- endif -%}
    {%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{-'<|PRE-ASSISTANT-GENERATE|>'-}}
{%- endif -%}
"""

    if 'context' in params and '<|system-message|>' in params['context']:
        pre_system = params['context'].split('<|system-message|>')[0]
        post_system = params['context'].split('<|system-message|>')[1]
    else:
        pre_system = ''
        post_system = ''

    pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
    post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]

    pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
    pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
    post_assistant = params['turn_template'].split('<|bot-message|>')[1]

    def preprocess(string):
        return string.replace('\n', '\\n').replace('\'', '\\\'')

    pre_system = preprocess(pre_system)
    post_system = preprocess(post_system)
    pre_user = preprocess(pre_user)
    post_user = preprocess(post_user)
    pre_assistant = preprocess(pre_assistant)
    post_assistant = preprocess(post_assistant)

    if verbose:
        print(
            '\n',
            repr(pre_system) + '\n',
            repr(post_system) + '\n',
            repr(pre_user) + '\n',
            repr(post_user) + '\n',
            repr(pre_assistant) + '\n',
            repr(post_assistant) + '\n',
        )

    result = MASTER_TEMPLATE
    if 'system_message' in params:
        result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message']))
    else:
        result = result.replace('<|SYSTEM-MESSAGE|>', '')

    result = result.replace('<|PRE-SYSTEM|>', pre_system)
    result = result.replace('<|POST-SYSTEM|>', post_system)
    result = result.replace('<|PRE-USER|>', pre_user)
    result = result.replace('<|POST-USER|>', post_user)
    result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
    result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
    result = result.replace('<|POST-ASSISTANT|>', post_assistant)

    result = result.strip()

    return result
@@ -16,9 +16,10 @@ default_preset_values = {
     'dynatemp_exponent': 1,
     'smoothing_factor': 0,
     'smoothing_curve': 1,
-    'min_p': 0,
     'top_p': 1,
     'top_k': 0,
+    'min_p': 0,
+    'top_n_sigma': 0,
     'typical_p': 1,
     'xtc_threshold': 0.1,
     'xtc_probability': 0,
@@ -26,7 +27,6 @@ default_preset_values = {
     'eta_cutoff': 0,
     'tfs': 1,
     'top_a': 0,
-    'top_n_sigma': 0,
     'adaptive_target': 0,
     'adaptive_decay': 0.9,
     'dry_multiplier': 0,
94 modules/reasoning.py Normal file
@@ -0,0 +1,94 @@
import html as html_module

# Thinking block format definitions: (start_tag, end_tag, content_start_tag)
# Use None for start_tag to match from beginning (end-only formats should be listed last)
THINKING_FORMATS = [
    ('<think>', '</think>', None),
    ('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'),
    ('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'),
    ('<seed:think>', '</seed:think>', None),
    ('<|think|>', '<|end|>', '<|content|>'),  # Solar Open
    # ('Thinking Process:', '</think>', None),  # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming
    (None, '</think>', None),  # End-only variant (e.g., Qwen3-next)
]


def extract_reasoning(text, html_escaped=False):
    """Extract reasoning/thinking blocks from the beginning of a string.

    When html_escaped=True, tags are HTML-escaped before searching
    (for use on already-escaped UI strings).

    Returns (reasoning_content, final_content) where reasoning_content is
    None if no thinking block is found.
    """
    if not text:
        return None, text

    esc = html_module.escape if html_escaped else lambda s: s

    for start_tag, end_tag, content_tag in THINKING_FORMATS:
        end_esc = esc(end_tag)
        content_esc = esc(content_tag) if content_tag else None

        if start_tag is None:
            # End-only format: require end tag, start from beginning
            end_pos = text.find(end_esc)
            if end_pos == -1:
                continue
            thought_start = 0
        else:
            # Normal format: require start tag
            start_esc = esc(start_tag)
            start_pos = text.find(start_esc)
            if start_pos == -1:
                # During streaming, the start tag may be arriving partially.
                # If the text is a prefix of a start tag, return empty content
                # to prevent the partial tag from leaking.
                stripped = text.strip()
                if stripped and start_esc.startswith(stripped):
                    return '', ''
                continue
            thought_start = start_pos + len(start_esc)
            end_pos = text.find(end_esc, thought_start)

        if end_pos == -1:
            # End tag missing - check if content tag can serve as fallback
            if content_esc:
                content_pos = text.find(content_esc, thought_start)
                if content_pos != -1:
                    thought_end = content_pos
                    content_start = content_pos + len(content_esc)
                else:
                    thought_end = len(text)
                    content_start = len(text)
            else:
                thought_end = len(text)
                content_start = len(text)
        else:
            thought_end = end_pos
            if content_esc:
                content_pos = text.find(content_esc, end_pos)
                if content_pos != -1:
                    content_start = content_pos + len(content_esc)
                else:
                    # Content tag expected but not yet present (e.g. partial
                    # streaming): suppress intermediate tags between end_tag
                    # and content_tag so they don't leak as content.
                    content_start = len(text)
            else:
                content_start = end_pos + len(end_esc)

        return text[thought_start:thought_end], text[content_start:]

    # Handle standalone GPT-OSS final channel marker without a preceding
    # analysis/commentary block (the model skipped thinking entirely).
    for marker in ['<|start|>assistant<|channel|>final<|message|>', '<|channel|>final<|message|>']:
        marker_esc = esc(marker)
        pos = text.find(marker_esc)
        if pos != -1:
            before = text[:pos].strip()
            after = text[pos + len(marker_esc):]
            return (before if before else None), after

    return None, text
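extract_reasoning() is the single entry point that splits "thinking" text from the final answer. As a minimal standalone sketch (a stand-in, not the module itself; the real function also handles end-only tags, channel formats, and HTML-escaped input), the basic `<think>` path behaves like this:

```python
def extract_think(text):
    # Mirrors only the '<think>'/'</think>' branch of extract_reasoning()
    start, end = '<think>', '</think>'
    pos = text.find(start)
    if pos == -1:
        return None, text  # no thinking block found
    thought_start = pos + len(start)
    end_pos = text.find(end, thought_start)
    if end_pos == -1:
        # End tag not arrived yet (streaming): everything so far is reasoning
        return text[thought_start:], ''
    return text[thought_start:end_pos], text[end_pos + len(end):]

print(extract_think('<think>check units</think>42 km'))  # ('check units', '42 km')
```

Note how a still-open block during streaming yields all accumulated text as reasoning with an empty final answer, so partial tags never leak into the visible reply.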
@@ -47,7 +47,7 @@ parser = argparse.ArgumentParser(description="Text Generation Web UI", conflict_
 # Basic settings
 group = parser.add_argument_group('Basic settings')
 group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user data directory. Default: auto-detected.')
-group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.')
+group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.')
 group.add_argument('--model', type=str, help='Name of the model to load by default.')
 group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
 group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.')
@@ -76,7 +76,7 @@ group.add_argument('--loader', type=str, help='Choose the model loader manually,
 
 # Cache
 group = parser.add_argument_group('Context and cache')
-group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, metavar='N', help='Context size in tokens. llama.cpp: 0 = auto if gpu-layers is also -1.')
+group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=0, metavar='N', help='Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.')
 group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
 
 # Speculative decoding
@@ -108,7 +108,7 @@ group.add_argument('--threads', type=int, default=0, help='Number of threads to
 group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
 group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
 group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
-group.add_argument('--fit-target', type=str, default='1024', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices. Default: 1024.')
+group.add_argument('--fit-target', type=str, default='512', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.')
 group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
 
 # Transformers/Accelerate
@@ -139,12 +139,6 @@ group.add_argument('--enable-tp', '--enable_tp', action='store_true', help='Enab
 group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.')
 group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
 
-# RoPE
-group = parser.add_argument_group('RoPE')
-group.add_argument('--alpha_value', type=float, default=1, help='Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.')
-group.add_argument('--rope_freq_base', type=int, default=0, help='If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).')
-group.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.")
-
 # Gradio
 group = parser.add_argument_group('Gradio')
 group.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
@@ -163,7 +157,7 @@ group.add_argument('--portable', action='store_true', help='Hide features not av
 # API
 group = parser.add_argument_group('API')
 group.add_argument('--api', action='store_true', help='Enable the API extension.')
-group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
+group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')
 group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
 group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
 group.add_argument('--api-key', type=str, default='', help='API authentication key.')
@@ -181,9 +175,10 @@ group.add_argument('--dynatemp-high', type=float, default=_d['dynatemp_high'], m
 group.add_argument('--dynatemp-exponent', type=float, default=_d['dynatemp_exponent'], metavar='N', help='Dynamic temperature exponent')
 group.add_argument('--smoothing-factor', type=float, default=_d['smoothing_factor'], metavar='N', help='Smoothing factor')
 group.add_argument('--smoothing-curve', type=float, default=_d['smoothing_curve'], metavar='N', help='Smoothing curve')
-group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P')
 group.add_argument('--top-p', type=float, default=_d['top_p'], metavar='N', help='Top P')
 group.add_argument('--top-k', type=int, default=_d['top_k'], metavar='N', help='Top K')
+group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P')
+group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma')
 group.add_argument('--typical-p', type=float, default=_d['typical_p'], metavar='N', help='Typical P')
 group.add_argument('--xtc-threshold', type=float, default=_d['xtc_threshold'], metavar='N', help='XTC threshold')
 group.add_argument('--xtc-probability', type=float, default=_d['xtc_probability'], metavar='N', help='XTC probability')
@@ -191,7 +186,6 @@ group.add_argument('--epsilon-cutoff', type=float, default=_d['epsilon_cutoff'],
 group.add_argument('--eta-cutoff', type=float, default=_d['eta_cutoff'], metavar='N', help='Eta cutoff')
 group.add_argument('--tfs', type=float, default=_d['tfs'], metavar='N', help='TFS')
 group.add_argument('--top-a', type=float, default=_d['top_a'], metavar='N', help='Top A')
-group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma')
 group.add_argument('--adaptive-target', type=float, default=_d['adaptive_target'], metavar='N', help='Adaptive target')
 group.add_argument('--adaptive-decay', type=float, default=_d['adaptive_decay'], metavar='N', help='Adaptive decay')
 group.add_argument('--dry-multiplier', type=float, default=_d['dry_multiplier'], metavar='N', help='DRY multiplier')
@@ -263,8 +257,9 @@ settings = {
     'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>',
     'enable_web_search': False,
     'web_search_pages': 3,
+    'selected_tools': [],
     'prompt-notebook': '',
-    'preset': 'Qwen3 - Thinking' if (user_data_dir / 'presets/Qwen3 - Thinking.yaml').exists() else None,
+    'preset': 'Top-P' if (user_data_dir / 'presets/Top-P.yaml').exists() else None,
     'max_new_tokens': 512,
     'max_new_tokens_min': 1,
     'max_new_tokens_max': 4096,
@@ -289,7 +284,7 @@ settings = {
     'include_past_attachments': True,
 
     # Generation parameters - Curve shape
-    'temperature': 0.6,
+    'temperature': neutral_samplers['temperature'],
     'dynatemp_low': neutral_samplers['dynatemp_low'],
     'dynatemp_high': neutral_samplers['dynatemp_high'],
     'dynatemp_exponent': neutral_samplers['dynatemp_exponent'],
@@ -297,9 +292,10 @@ settings = {
     'smoothing_curve': neutral_samplers['smoothing_curve'],
 
     # Generation parameters - Curve cutoff
-    'min_p': neutral_samplers['min_p'],
     'top_p': 0.95,
-    'top_k': 20,
+    'top_k': neutral_samplers['top_k'],
+    'min_p': neutral_samplers['min_p'],
+    'top_n_sigma': neutral_samplers['top_n_sigma'],
     'typical_p': neutral_samplers['typical_p'],
     'xtc_threshold': neutral_samplers['xtc_threshold'],
     'xtc_probability': neutral_samplers['xtc_probability'],
@@ -307,7 +303,6 @@ settings = {
     'eta_cutoff': neutral_samplers['eta_cutoff'],
     'tfs': neutral_samplers['tfs'],
     'top_a': neutral_samplers['top_a'],
-    'top_n_sigma': neutral_samplers['top_n_sigma'],
     'adaptive_target': neutral_samplers['adaptive_target'],
     'adaptive_decay': neutral_samplers['adaptive_decay'],
 
@@ -347,7 +342,7 @@ settings = {
     'greeting': 'How can I help you today?',
     'custom_system_message': '',
     'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}",
-    'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}",
+    'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {{- '[Tool result: ' + message['content'] + ']\\n' -}}\n {%- elif message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- elif message['tool_calls'] is defined and message['tool_calls'] -%}\n {%- for tc in message['tool_calls'] -%}\n {{- '[Calling: ' + tc['function']['name'] + '(' + tc['function']['arguments'] + ')]\\n' -}}\n {%- endfor -%}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}",
 
     # Extensions
     'default_extensions': [],
@@ -395,9 +390,16 @@ def do_cmd_flags_warnings():
     if args.share:
         logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
     if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
-        logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
-    if args.multi_user:
-        logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
+        logger.warning("You are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
+    if args.multi_user:
+        logger.warning(
+            'Multi-user mode is enabled. Known limitations:'
+            '\n- The Stop button stops generation for all users, not just you.'
+            '\n- Chat history is not saved and will be lost on page refresh.'
+            '\n- Only one user can generate at a time unless using a parallel-capable backend (e.g. llama.cpp with --parallel N for N > 1, or ExLlamaV3).'
+            '\n\nThis mode works best for small trusted teams.'
+            '\n\nDo not expose publicly. Grayed-out actions can easily be bypassed client-side.\n'
+        )
 
 
 def apply_image_model_cli_overrides():
@@ -78,10 +78,13 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     reply = ''
     is_stream = state['stream']
     if len(all_stop_strings) > 0 and not state['stream']:
         original_logits_processor = state.get('logits_processor')
+        stop_event_ref = state.pop('stop_event', None)
         state = copy.deepcopy(state)
+        if stop_event_ref is not None:
+            state['stop_event'] = stop_event_ref
         if original_logits_processor is not None:
             state['logits_processor'] = original_logits_processor
         state['stream'] = True
 
     # Generate
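The hunk above pops the shared stop event out of the state dict before the deepcopy and re-attaches the original object afterwards: the event must stay the *same* object for the Stop button to reach the running generation, and deep-copying it would fail anyway on its internal thread lock. The pattern in isolation, with a hypothetical state dict:

```python
import copy
import threading

state = {'stream': False, 'stop_event': threading.Event(), 'temperature': 0.7}

# Pop the live handle out, deepcopy the rest, then re-attach the original object
stop_event_ref = state.pop('stop_event', None)
state_copy = copy.deepcopy(state)
if stop_event_ref is not None:
    state_copy['stop_event'] = stop_event_ref

print(state_copy['stop_event'] is stop_event_ref)  # True: same Event, not a copy
```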
@@ -375,7 +378,7 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
         generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
 
     if state['custom_token_bans']:
-        to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
+        to_ban = [int(x.strip()) for x in state['custom_token_bans'].split(',') if x.strip()]
         if len(to_ban) > 0:
             if generate_params.get('suppress_tokens', None):
                 generate_params['suppress_tokens'] += to_ban
667 modules/tool_parsing.py Normal file

@@ -0,0 +1,667 @@
import json
import random
import re


def get_tool_call_id() -> str:
    letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
    b = [random.choice(letter_bytes) for _ in range(8)]
    return "call_" + "".join(b).lower()


# All known opening markers for tool calls across model formats.
TOOL_CALL_OPENING_MARKERS = [
    '<tool_call>',
    '<function_call>',
    '<minimax:tool_call>',
    '<|tool_call_begin|>',
    '<|tool_calls_section_begin|>',
    '<|tool▁call▁begin|>',
    '<|tool▁calls▁begin|>',
    '[TOOL_CALLS]',
    'to=functions.',
    '<|channel|>commentary',
]
def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False):
    '''
    Check whether streaming output should be withheld because it may
    contain tool-call markup.

    Args:
        text: Full accumulated internal text.
        markers: Template-specific markers for partial-prefix matching.
            If None, falls back to TOOL_CALL_OPENING_MARKERS.
        tool_names: List of tool function names.
        check_bare_names: Whether to do partial-prefix matching on tool
            names (for models with unknown template format).
    '''
    # Full marker found in text → buffer permanently.
    # Always checks ALL known markers regardless of template (cheap safety net).
    for marker in TOOL_CALL_OPENING_MARKERS:
        if marker in text:
            return True

    # Bare function-name full match: "get_weather{...}" or "get_weather {...}"
    if tool_names:
        for name in tool_names:
            if name + '{' in text or name + ' {' in text:
                return True

    # Partial-prefix matching: only for template-specific markers.
    for marker in (markers if markers is not None else TOOL_CALL_OPENING_MARKERS):
        for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1):
            if text.endswith(marker[:prefix_len]):
                return True

    # Bare-name partial matching: only when template format is unknown.
    if check_bare_names and tool_names:
        for name in tool_names:
            if text.endswith(name):
                return True
            for prefix_len in range(min(len(name) - 1, len(text)), 0, -1):
                if text.endswith(name[:prefix_len]):
                    return True

    return False
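The partial-prefix check is what keeps a half-arrived marker like `<tool` from being shown to the user mid-stream. A self-contained sketch of just that step (marker and inputs are illustrative, not from the module):

```python
def ends_with_marker_prefix(text, marker):
    # True if text ends with any proper prefix of marker (longest first)
    for plen in range(min(len(marker) - 1, len(text)), 0, -1):
        if text.endswith(marker[:plen]):
            return True
    return False

print(ends_with_marker_prefix('Sure, one moment. <tool', '<tool_call>'))  # True
print(ends_with_marker_prefix('Sure, one moment.', '<tool_call>'))        # False
```

The descending loop tries longer prefixes first, so a single call decides whether the stream must pause until the next chunk disambiguates the tail.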
def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]):
    # check if property 'function' exists and is a dictionary, otherwise adapt dict
    if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
        candidate_dict = {"type": "function", "function": candidate_dict}
    if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
        candidate_dict['name'] = candidate_dict['function']
        del candidate_dict['function']
        candidate_dict = {"type": "function", "function": candidate_dict}
    if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
        # check if 'name' exists within 'function' and is part of known tools
        if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
            candidate_dict["type"] = "function"  # ensure required property 'type' exists and has the right value
            # map property 'parameters' used by some older models to 'arguments'
            if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
                candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
                del candidate_dict["function"]["parameters"]
            return candidate_dict
    return None
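The normalization above can be condensed into a small stand-in for illustration (simplified: the real function also unwraps a string-valued 'function' key; `get_weather` is a hypothetical tool name):

```python
def normalize(candidate, tool_names):
    # Wrap bare {"name": ..., "parameters": ...} dicts into OpenAI-style calls
    if 'function' not in candidate and 'name' in candidate:
        candidate = {"type": "function", "function": candidate}
    fn = candidate.get('function')
    if isinstance(fn, dict) and fn.get('name') in tool_names:
        if 'arguments' not in fn and 'parameters' in fn:
            fn['arguments'] = fn.pop('parameters')  # older models emit 'parameters'
        candidate['type'] = 'function'
        return candidate
    return None

print(normalize({'name': 'get_weather', 'parameters': {'city': 'Paris'}}, ['get_weather']))
```

Unknown tool names yield None, so hallucinated calls are dropped instead of being forwarded.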
def _extract_balanced_json(text: str, start: int) -> str | None:
    """Extract a balanced JSON object from text starting at the given position.

    Walks through the string tracking brace depth and string boundaries
    to correctly handle arbitrary nesting levels.
    """
    if start >= len(text) or text[start] != '{':
        return None
    depth = 0
    in_string = False
    escape_next = False
    for i in range(start, len(text)):
        c = text[i]
        if escape_next:
            escape_next = False
            continue
        if c == '\\' and in_string:
            escape_next = True
            continue
        if c == '"':
            in_string = not in_string
            continue
        if in_string:
            continue
        if c == '{':
            depth += 1
        elif c == '}':
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None
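Because tool arguments can nest objects and contain braces inside string values, a regex cannot reliably delimit the JSON; the depth-tracking walk above handles both cases. The same logic can be exercised standalone (a copy for demonstration, not an import from the module):

```python
def extract_balanced_json(text, start):
    # Walk the string tracking brace depth; braces inside JSON strings are ignored
    if start >= len(text) or text[start] != '{':
        return None
    depth, in_string, escape_next = 0, False, False
    for i in range(start, len(text)):
        c = text[i]
        if escape_next:
            escape_next = False
            continue
        if c == '\\' and in_string:
            escape_next = True
            continue
        if c == '"':
            in_string = not in_string
            continue
        if in_string:
            continue
        if c == '{':
            depth += 1
        elif c == '}':
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None

s = 'call: {"a": {"b": "}"}} trailing'
print(extract_balanced_json(s, s.index('{')))  # {"a": {"b": "}"}}
```

Note the `"}"` inside the string is skipped, and the trailing text after the object is left untouched for the caller.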
def _parse_channel_tool_calls(answer: str, tool_names: list[str]):
    """Parse channel-based tool calls used by GPT-OSS and similar models.

    Format:
        <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"}
    or:
        <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
    """
    matches = []
    start_pos = None
    # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
    # Pattern 2: to=functions.NAME after <|channel|> (alternative format)
    patterns = [
        r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
        r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
    ]
    for pattern in patterns:
        for m in re.finditer(pattern, answer):
            func_name = m.group(1).strip()
            if func_name not in tool_names:
                continue
            json_str = _extract_balanced_json(answer, m.end())
            if json_str is None:
                continue
            try:
                arguments = json.loads(json_str)
                if start_pos is None:
                    prefix = answer.rfind('<|start|>assistant', 0, m.start())
                    start_pos = prefix if prefix != -1 else m.start()
                matches.append({
                    "type": "function",
                    "function": {
                        "name": func_name,
                        "arguments": arguments
                    }
                })
            except json.JSONDecodeError:
                pass
        if matches:
            break
    return matches, start_pos
def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]):
    """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.

    Format:
        [TOOL_CALLS]func_name[ARGS]{"arg": "value"}
    """
    matches = []
    start_pos = None
    for m in re.finditer(
        r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        json_str = _extract_balanced_json(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                start_pos = m.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            pass
    return matches, start_pos
def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]):
    """Parse bare function-name style tool calls used by Mistral and similar models.

    Format:
        functionName{"arg": "value"}
    Multiple calls are concatenated directly or separated by whitespace.
    """
    matches = []
    start_pos = None
    # Match tool name followed by opening brace, then extract balanced JSON
    escaped_names = [re.escape(name) for name in tool_names]
    pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
    for match in re.finditer(pattern, answer):
        text = match.group(0)
        name = None
        for n in tool_names:
            if text.startswith(n):
                name = n
                break
        if not name:
            continue
        brace_start = match.end() - 1
        json_str = _extract_balanced_json(answer, brace_start)
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                start_pos = match.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            pass
    return matches, start_pos
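The name-then-brace approach can also be sketched with the stdlib's `json.JSONDecoder.raw_decode`, which parses one JSON value at a given offset and reports where it ended; this is an alternative illustration of the same idea, not the module's implementation (`get_weather` is a hypothetical tool name):

```python
import json
import re

def parse_bare_calls(answer, tool_names):
    # Find "name{...}" occurrences; let raw_decode consume the JSON object
    decoder = json.JSONDecoder()
    pattern = r'(?:' + '|'.join(map(re.escape, tool_names)) + r')\s*\{'
    calls = []
    for m in re.finditer(pattern, answer):
        try:
            args, _ = decoder.raw_decode(answer, m.end() - 1)
        except json.JSONDecodeError:
            continue  # not valid JSON after the brace; skip this candidate
        name = next(n for n in tool_names if m.group(0).startswith(n))
        calls.append({"type": "function", "function": {"name": name, "arguments": args}})
    return calls

print(parse_bare_calls('get_weather{"city": "Paris"}', ['get_weather']))
```

The manual brace walk in the module keeps the matched substring around for position bookkeeping; `raw_decode` trades that for brevity.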
def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]):
    """Parse XML-parameter style tool calls used by Qwen3.5 and similar models.

    Format:
        <tool_call>
        <function=function_name>
        <parameter=param_name>value</parameter>
        </function>
        </tool_call>
    """
    matches = []
    start_pos = None
    for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
        tc_content = tc_match.group(1)
        func_match = re.search(r'<function=([^>]+)>', tc_content)
        if not func_match:
            continue
        func_name = func_match.group(1).strip()
        if func_name not in tool_names:
            continue
        arguments = {}
        for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
            param_name = param_match.group(1).strip()
            param_value = param_match.group(2).strip()
            try:
                param_value = json.loads(param_value)
            except (json.JSONDecodeError, ValueError):
                pass  # keep as string
            arguments[param_name] = param_value
        if start_pos is None:
            start_pos = tc_match.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })
    return matches, start_pos
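The per-parameter `json.loads` attempt is what recovers typed values from the XML-parameter format: numeric and boolean literals become Python values while anything else stays a string. A standalone sketch against a hypothetical model output (`get_weather`, `city`, `days` are illustrative names):

```python
import json
import re

# Hypothetical model output in the XML-parameter format described above
text = """<tool_call>
<function=get_weather>
<parameter=city>Paris</parameter>
<parameter=days>3</parameter>
</function>
</tool_call>"""

args = {}
for m in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', text, re.DOTALL):
    value = m.group(2)
    try:
        value = json.loads(value)  # "3" becomes the int 3
    except (json.JSONDecodeError, ValueError):
        pass  # non-JSON values like "Paris" stay strings
    args[m.group(1)] = value

print(args)  # {'city': 'Paris', 'days': 3}
```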
def _parse_kimi_tool_calls(answer: str, tool_names: list[str]):
    """Parse Kimi-K2-style tool calls using pipe-delimited tokens.

    Format:
        <|tool_calls_section_begin|>
        <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
        <|tool_calls_section_end|>
    """
    matches = []
    start_pos = None
    for m in re.finditer(
        r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        json_str = _extract_balanced_json(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                # Check for section begin marker before the call marker
                section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start())
                start_pos = section if section != -1 else m.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            pass
    return matches, start_pos
def _parse_minimax_tool_calls(answer: str, tool_names: list[str]):
    """Parse MiniMax-style tool calls using invoke/parameter XML tags.

    Format:
        <minimax:tool_call>
        <invoke name="function_name">
        <parameter name="param_name">value</parameter>
        </invoke>
        </minimax:tool_call>
    """
    matches = []
    start_pos = None
    for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
        tc_content = tc_match.group(1)
        # Iterate over <invoke> blocks to handle multiple parallel calls in one block
        for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
            func_name = invoke_match.group(1).strip()
            if func_name not in tool_names:
                continue
            invoke_body = invoke_match.group(2)
            arguments = {}
            for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
                param_name = param_match.group(1).strip()
                param_value = param_match.group(2).strip()
                try:
                    param_value = json.loads(param_value)
                except (json.JSONDecodeError, ValueError):
                    pass  # keep as string
                arguments[param_name] = param_value
            if start_pos is None:
                start_pos = tc_match.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
    return matches, start_pos

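The MiniMax regexes can be exercised on a small synthetic completion (the tool name and arguments here are invented for illustration):

```python
import json
import re

answer = '''Let me check the weather.
<minimax:tool_call>
<invoke name="get_weather">
<parameter name="city">Paris</parameter>
<parameter name="days">3</parameter>
</invoke>
</minimax:tool_call>'''

calls = []
for tc in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
    for invoke in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc.group(1), re.DOTALL):
        args = {}
        for p in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke.group(2), re.DOTALL):
            value = p.group(2).strip()
            try:
                value = json.loads(value)  # "3" decodes to 3; non-JSON text stays a string
            except (json.JSONDecodeError, ValueError):
                pass
            args[p.group(1).strip()] = value
        calls.append({"name": invoke.group(1).strip(), "arguments": args})

print(calls)  # [{'name': 'get_weather', 'arguments': {'city': 'Paris', 'days': 3}}]
```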
def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]):
    """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.

    Format:
        <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
    """
    matches = []
    start_pos = None
    for m in re.finditer(
        r'<\|tool▁call▁begin\|>\s*(\S+?)\s*<\|tool▁sep\|>\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        json_str = _extract_balanced_json(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                # Check for section begin marker before the call marker
                section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start())
                start_pos = section if section != -1 else m.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            pass
    return matches, start_pos

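The `▁` characters in the DeepSeek delimiters are ordinary Unicode text, but the `|` characters must be escaped in the regex or they act as alternation operators. A quick standalone check (the sample call is invented for illustration):

```python
import json
import re

answer = '<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{"tz": "UTC"}<|tool▁call▁end|><|tool▁calls▁end|>'

# Escaped pipes: \| matches a literal | instead of splitting the pattern
m = re.search(r'<\|tool▁call▁begin\|>\s*(\S+?)\s*<\|tool▁sep\|>\s*', answer)
func_name = m.group(1)
# The JSON payload starts right after the separator; here it ends at the next marker
payload = answer[m.end():answer.index('<|tool▁call▁end|>')]
arguments = json.loads(payload)
print(func_name, arguments)
```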
def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
    """Parse GLM-style tool calls using arg_key/arg_value XML pairs.

    Format:
        <tool_call>function_name
        <arg_key>key1</arg_key>
        <arg_value>value1</arg_value>
        </tool_call>
    """
    matches = []
    start_pos = None
    for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
        tc_content = tc_match.group(1)
        # First non-tag text is the function name
        name_match = re.match(r'([^<\s]+)', tc_content.strip())
        if not name_match:
            continue
        func_name = name_match.group(1).strip()
        if func_name not in tool_names:
            continue
        # Extract arg_key/arg_value pairs
        keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
        vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
        if len(keys) != len(vals):
            continue
        arguments = {}
        for k, v in zip(keys, vals):
            try:
                v = json.loads(v)
            except (json.JSONDecodeError, ValueError):
                pass  # keep as string
            arguments[k] = v
        if start_pos is None:
            start_pos = tc_match.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })
    return matches, start_pos

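The arg_key/arg_value pairing can be sketched standalone (function name and keys invented for illustration):

```python
import json
import re

answer = ('<tool_call>search\n'
          '<arg_key>query</arg_key>\n<arg_value>llama.cpp</arg_value>\n'
          '<arg_key>limit</arg_key>\n<arg_value>5</arg_value>\n'
          '</tool_call>')

tc = re.search(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL).group(1)
name = re.match(r'([^<\s]+)', tc.strip()).group(1)
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc, re.DOTALL)]
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc, re.DOTALL)]

arguments = {}
for k, v in zip(keys, vals):
    try:
        v = json.loads(v)  # numbers/booleans decode; plain text stays a string
    except (json.JSONDecodeError, ValueError):
        pass
    arguments[k] = v

print(name, arguments)  # search {'query': 'llama.cpp', 'limit': 5}
```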
def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
    """Parse pythonic-style tool calls used by Llama 4 and similar models.

    Format:
        [func_name(param1="value1", param2="value2"), func_name2(...)]
    """
    matches = []
    start_pos = None
    # Match a bracketed list of function calls
    bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
    if not bracket_match:
        return matches, start_pos

    inner = bracket_match.group(1)

    # Build pattern for known tool names
    escaped_names = [re.escape(name) for name in tool_names]
    name_pattern = '|'.join(escaped_names)

    for call_match in re.finditer(
        r'(' + name_pattern + r')\(([^)]*)\)',
        inner
    ):
        func_name = call_match.group(1)
        params_str = call_match.group(2).strip()
        arguments = {}

        if params_str:
            # Parse key="value" pairs, handling commas inside quoted values
            for param_match in re.finditer(
                r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
                params_str
            ):
                param_name = param_match.group(1)
                param_value = param_match.group(2).strip()
                # Strip surrounding quotes
                if (param_value.startswith('"') and param_value.endswith('"')) or \
                        (param_value.startswith("'") and param_value.endswith("'")):
                    param_value = param_value[1:-1]
                # Try to parse as JSON for numeric/bool/null values
                try:
                    param_value = json.loads(param_value)
                except (json.JSONDecodeError, ValueError):
                    pass
                arguments[param_name] = param_value

        if start_pos is None:
            start_pos = bracket_match.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })

    return matches, start_pos

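The quoted-string alternative in the parameter regex is what keeps commas inside string values from splitting a value in two. A standalone run with an invented call shows this:

```python
import json
import re

answer = 'Sure: [get_weather(city="Paris, France", days=3)]'
tool_names = ['get_weather']

inner = re.search(r'\[([^\[\]]+)\]', answer).group(1)
name_pattern = '|'.join(re.escape(n) for n in tool_names)
calls = []
for call in re.finditer(r'(' + name_pattern + r')\(([^)]*)\)', inner):
    args = {}
    # key="value" pairs; the quoted alternative keeps the comma inside "Paris, France" intact
    for p in re.finditer(r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)', call.group(2)):
        value = p.group(2).strip()
        if value[:1] in ('"', "'") and value[:1] == value[-1:]:
            value = value[1:-1]
        try:
            value = json.loads(value)  # 3 becomes an int; bare text stays a string
        except (json.JSONDecodeError, ValueError):
            pass
        args[p.group(1)] = value
    calls.append((call.group(1), args))

print(calls)  # [('get_weather', {'city': 'Paris, France', 'days': 3})]
```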
# Format registry: maps template substrings to the parser and streaming
# markers for that format. When a format's hints are NOT found in the
# template, its parser and markers are excluded.
TOOL_CALL_FORMATS = [
    {
        'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'],
        'parser': _parse_deep_seek_tool_calls,
        'markers': ['<|tool▁call▁begin|>', '<|tool▁calls▁begin|>'],
    },
    {
        'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'],
        'parser': _parse_kimi_tool_calls,
        'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'],
    },
    {
        'template_hints': ['to=functions.', '<|channel|>'],
        'parser': _parse_channel_tool_calls,
        'markers': ['to=functions.', '<|channel|>commentary'],
    },
    {
        'template_hints': ['minimax:tool_call'],
        'parser': _parse_minimax_tool_calls,
        'markers': ['<minimax:tool_call>'],
    },
    {
        'template_hints': ['<arg_key>'],
        'parser': _parse_glm_tool_calls,
        'markers': ['<tool_call>'],
    },
    {
        'template_hints': ['<tool_call>'],
        'parser': _parse_xml_param_tool_calls,
        'markers': ['<tool_call>'],
    },
    {
        'template_hints': ['[TOOL_CALLS]'],
        'parser': _parse_mistral_token_tool_calls,
        'markers': ['[TOOL_CALLS]'],
    },
    {
        'template_hints': ['<function_call>'],
        'parser': None,
        'markers': ['<function_call>'],
    },
]

# Default ordered list of all specialized parsers.
ALL_PARSERS = [
    _parse_deep_seek_tool_calls,
    _parse_kimi_tool_calls,
    _parse_channel_tool_calls,
    _parse_minimax_tool_calls,
    _parse_glm_tool_calls,
    _parse_xml_param_tool_calls,
    _parse_mistral_token_tool_calls,
    _parse_bare_name_tool_calls,
    _parse_pythonic_tool_calls,
]

def detect_tool_call_format(template_str):
    """Inspect a chat/instruction template to determine which tool call
    formats are relevant.

    Uses an exclude-based approach: starts with all parsers/markers,
    then removes the ones whose hints are not found in the template.

    Returns (parsers, streaming_markers, check_bare_names).
    """
    if not template_str:
        return None, TOOL_CALL_OPENING_MARKERS, True

    matched_any = False
    exclude_parsers = []
    exclude_markers = []
    matched_markers = []

    for fmt in TOOL_CALL_FORMATS:
        if any(hint in template_str for hint in fmt['template_hints']):
            matched_any = True
            matched_markers.extend(fmt['markers'])
        else:
            if fmt['parser'] is not None:
                exclude_parsers.append(fmt['parser'])
            exclude_markers.extend(fmt['markers'])

    if not matched_any:
        return None, TOOL_CALL_OPENING_MARKERS, True

    parsers = [p for p in ALL_PARSERS if p not in exclude_parsers]
    markers = [m for m in TOOL_CALL_OPENING_MARKERS if m not in exclude_markers or m in matched_markers]

    return parsers, markers, False

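The exclude-based selection can be illustrated with a toy registry (parsers replaced by plain strings for brevity; the names and markers here are illustrative stand-ins for the real entries above):

```python
FORMATS = [
    {'hints': ['<minimax:tool_call>'], 'parser': 'minimax', 'markers': ['<minimax:tool_call>']},
    {'hints': ['[TOOL_CALLS]'], 'parser': 'mistral', 'markers': ['[TOOL_CALLS]']},
]
ALL = ['minimax', 'mistral', 'generic']
ALL_MARKERS = ['<minimax:tool_call>', '[TOOL_CALLS]', '{']


def detect(template):
    # No template: keep every parser and marker
    if not template:
        return None, ALL_MARKERS
    exclude_parsers, exclude_markers, matched = [], [], False
    for fmt in FORMATS:
        if any(h in template for h in fmt['hints']):
            matched = True
        else:
            exclude_parsers.append(fmt['parser'])
            exclude_markers.extend(fmt['markers'])
    if not matched:
        return None, ALL_MARKERS
    # Start from everything, subtract what the template cannot produce
    parsers = [p for p in ALL if p not in exclude_parsers]
    markers = [m for m in ALL_MARKERS if m not in exclude_markers]
    return parsers, markers


print(detect('... [TOOL_CALLS] ...'))  # (['mistral', 'generic'], ['[TOOL_CALLS]', '{'])
```

A template containing `[TOOL_CALLS]` keeps the Mistral parser and the generic fallback but drops the MiniMax parser and its streaming marker, so streaming never pauses on `<minimax:tool_call>` for a model that cannot emit it.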
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None):
    matches = []
    start_pos = None

    def _return(matches, start_pos):
        if return_prefix:
            prefix = answer[:start_pos] if matches and start_pos is not None else ''
            return matches, prefix
        return matches

    # Try specialized parsers.
    for parser in (parsers if parsers is not None else ALL_PARSERS):
        matches, start_pos = parser(answer, tool_names)
        if matches:
            return _return(matches, start_pos)

    # Generic fallback: regex patterns to find JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]

    for pattern in patterns:
        for match in re.finditer(pattern, answer, re.DOTALL):
            if match.group(2) is None:
                continue
            # Remove backtick wraps if present
            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
            candidate = re.sub(r"```$", "", candidate.strip())
            # Unwrap inner tags
            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
            # The LLM may have generated multiple JSON objects separated by line breaks; join them into a JSON array and parse each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
                if not candidate.strip().startswith("["):
                    candidate = "[" + candidate + "]"

            candidates = []
            try:
                # Parse the candidate JSON into a dictionary
                candidates = json.loads(candidate)
                if not isinstance(candidates, list):
                    candidates = [candidates]
            except json.JSONDecodeError:
                # Ignore invalid JSON silently
                continue

            for candidate_dict in candidates:
                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    if start_pos is None:
                        start_pos = match.start()
                    matches.append(checked_candidate)

    # Last resort if nothing has been matched: the LLM may have produced a plain JSON tool call without XML-like tags
    if len(matches) == 0:
        try:
            candidate = answer
            # Join multiple newline-separated JSON objects into a JSON array, then parse each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
                if not candidate.strip().startswith("["):
                    candidate = "[" + candidate + "]"
            # Parse the candidate JSON into a dictionary
            candidates = json.loads(candidate)
            if not isinstance(candidates, list):
                candidates = [candidates]
            for candidate_dict in candidates:
                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    matches.append(checked_candidate)
        except json.JSONDecodeError:
            # Ignore invalid JSON silently
            pass

    return _return(matches, start_pos)

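The newline-separated-objects normalization used by the generic fallback can be shown in isolation (tool names invented for illustration):

```python
import json
import re

candidate = '{"name": "get_weather", "arguments": {"city": "Paris"}}\n{"name": "get_time", "arguments": {}}'

# Turn consecutive objects into a JSON array: "}\n{" becomes "},\n{", then wrap in [ ]
if re.search(r'\}\s*\n\s*\{', candidate):
    candidate = re.sub(r'\}\s*\n\s*\{', '},\n{', candidate)
    if not candidate.strip().startswith('['):
        candidate = '[' + candidate + ']'

calls = json.loads(candidate)
print([c['name'] for c in calls])  # ['get_weather', 'get_time']
```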
modules/tool_use.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import importlib.util
import json

from modules import shared
from modules.logging_colors import logger
from modules.utils import natural_keys, sanitize_filename


def get_available_tools():
    """Return sorted list of tool script names from user_data/tools/*.py."""
    tools_dir = shared.user_data_dir / 'tools'
    tools_dir.mkdir(parents=True, exist_ok=True)
    return sorted((p.stem for p in tools_dir.glob('*.py')), key=natural_keys)


def load_tools(selected_names):
    """
    Import selected tool scripts and return their definitions and executors.
    Returns (tool_defs, executors) where:
    - tool_defs: list of OpenAI-format tool dicts
    - executors: dict mapping function_name -> execute callable
    """
    tool_defs = []
    executors = {}
    for name in selected_names:
        name = sanitize_filename(name)
        if not name:
            continue

        path = shared.user_data_dir / 'tools' / f'{name}.py'
        if not path.exists():
            continue

        try:
            spec = importlib.util.spec_from_file_location(f"tool_{name}", str(path))
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        except Exception:
            logger.exception(f'Failed to load tool script "{name}"')
            continue

        tool_def = getattr(module, 'tool', None)
        execute_fn = getattr(module, 'execute', None)
        if tool_def is None or execute_fn is None:
            logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.')
            continue

        func_name = tool_def.get('function', {}).get('name', name)
        if func_name in executors:
            logger.warning(f'Tool "{name}" declares function name "{func_name}" which conflicts with an already loaded tool. Skipping.')
            continue
        tool_defs.append(tool_def)
        executors[func_name] = execute_fn

    return tool_defs, executors


def execute_tool(func_name, arguments, executors):
    """Execute a tool by function name. Returns result as a JSON string."""
    fn = executors.get(func_name)
    if fn is None:
        return json.dumps({"error": f"Unknown tool: {func_name}"})

    try:
        if isinstance(arguments, str):
            arguments = json.loads(arguments)
        result = fn(arguments)
        return json.dumps(result) if not isinstance(result, str) else result
    except Exception as e:
        logger.exception(f'Tool "{func_name}" execution failed')
        return json.dumps({"error": str(e)})

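Given the `load_tools` contract above, a user tool script in `user_data/tools/` only needs a module-level `tool` dict (OpenAI function schema) and an `execute(arguments)` callable. A minimal hypothetical example (the function name and schema are illustrative, not shipped with the repo):

```python
# user_data/tools/add_numbers.py (hypothetical example tool script)
tool = {
    "type": "function",
    "function": {
        "name": "add_numbers",
        "description": "Add two numbers and return the sum.",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {"type": "number"},
                "b": {"type": "number"},
            },
            "required": ["a", "b"],
        },
    },
}


def execute(arguments):
    # execute_tool passes the parsed arguments dict; a non-string
    # return value gets JSON-serialized before reaching the model
    return {"sum": arguments["a"] + arguments["b"]}


print(execute({"a": 2, "b": 3}))  # {'sum': 5}
```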
@@ -310,6 +310,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:

    # == Input validation / processing ==
    yield "Preparing the input..."

+   if shared.args.loader == 'llama.cpp':
+       yield "Error: LoRA training requires a model loaded with the Transformers loader. GGUF models are not supported for training."
+       return
+
    lora_file_path = clean_path(None, lora_name)
    if lora_file_path.strip() == '':
        yield "Missing or invalid LoRA file name input."

@@ -65,14 +65,16 @@ class LogprobProcessor(LogitsProcessor):
    def __init__(self, logprobs=None):
        self.logprobs = logprobs
        self.token_alternatives = {}
+       self.token_alternatives_history = []

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        if self.logprobs is not None:  # 0-5
            log_e_probabilities = F.log_softmax(logits, dim=1)
-           top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
+           top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs)
            top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
            top_probs = [float(x) for x in top_values[0]]
            self.token_alternatives = dict(zip(top_tokens, top_probs))
+           self.token_alternatives_history.append(self.token_alternatives)

        return logits

@@ -134,8 +136,6 @@ def load_model_HF(model_name):
        shared.args.load_in_4bit,
        shared.args.disk,
        shared.args.cpu_memory is not None,
-       shared.args.compress_pos_emb > 1,
-       shared.args.alpha_value > 1,
    ])

    # Load the model without any special settings

@@ -198,11 +198,6 @@ def load_model_HF(model_name):
    if shared.args.disk:
        params['offload_folder'] = str(Path(shared.args.disk_cache_dir))

-   if shared.args.compress_pos_emb > 1:
-       params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
-   elif shared.args.alpha_value > 1:
-       params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}
-
    logger.info("TRANSFORMERS_PARAMS=")
    pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
    print()

@@ -120,58 +120,8 @@ else:


def list_model_elements():
-   elements = [
-       'filter_by_loader',
-       'loader',
-       'cpu_memory',
-       'gpu_layers',
-       'fit_target',
-       'cpu_moe',
-       'threads',
-       'threads_batch',
-       'batch_size',
-       'ubatch_size',
-       'ctx_size',
-       'cache_type',
-       'tensor_split',
-       'extra_flags',
-       'streaming_llm',
-       'gpu_split',
-       'alpha_value',
-       'rope_freq_base',
-       'compress_pos_emb',
-       'compute_dtype',
-       'quant_type',
-       'load_in_8bit',
-       'load_in_4bit',
-       'attn_implementation',
-       'cpu',
-       'disk',
-       'row_split',
-       'no_kv_offload',
-       'no_mmap',
-       'mlock',
-       'numa',
-       'parallel',
-       'use_double_quant',
-       'bf16',
-       'enable_tp',
-       'tp_backend',
-       'cfg_cache',
-       'no_use_fast',
-       'model_draft',
-       'draft_max',
-       'gpu_layers_draft',
-       'device_draft',
-       'ctx_size_draft',
-       'spec_type',
-       'spec_ngram_size_n',
-       'spec_ngram_size_m',
-       'spec_ngram_min_hits',
-       'mmproj',
-   ]
-
-   return elements
+   from modules.loaders import list_model_elements
+   return list_model_elements()


def list_interface_input_elements():

@@ -249,6 +199,7 @@ def list_interface_input_elements():
        'unique_id',
        'textbox',
        'start_with',
+       'selected_tools',
        'mode',
        'chat_style',
        'chat-instruct_command',

@@ -353,12 +304,16 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma
        if k in shared.settings and k not in exclude:
            output[k] = state[k]

-   output['preset'] = preset
+   if preset:
+       output['preset'] = preset
    output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook']
-   output['character'] = state['character_menu']
-   if 'user_menu' in state and state['user_menu']:
+   if state.get('character_menu'):
+       output['character'] = state['character_menu']
+   if state.get('user_menu'):
        output['user'] = state['user_menu']
    output['seed'] = int(output['seed'])
    output['custom_stopping_strings'] = output.get('custom_stopping_strings') or ''
    output['custom_token_bans'] = output.get('custom_token_bans') or ''
    output['show_controls'] = show_controls
    output['dark_theme'] = True if theme_state == 'dark' else False
    output.pop('instruction_template_str')

@@ -470,6 +425,7 @@ def setup_auto_save():
        'user_bio',
        'custom_system_message',
        'chat_template_str',
+       'selected_tools',

        # Parameters tab (ui_parameters.py) - Generation parameters
        'preset_menu',

@@ -520,7 +476,6 @@ def setup_auto_save():
        'skip_special_tokens',
        'stream',
-       'static_cache',
        'truncation_length',
        'seed',
        'sampler_priority',
        'custom_stopping_strings',

@@ -28,7 +28,8 @@ def create_ui():
    shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu)
    shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu)
    shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat')
-   shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'])
+   shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'], elem_id='new-chat-btn')
+   shared.gradio['Start incognito chat'] = gr.Button('Incognito chat', visible=False, elem_id='incognito-chat-btn')
    shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True)

    shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat')

@@ -91,6 +92,21 @@ def create_ui():

    gr.HTML("<div class='sidebar-vertical-separator'></div>")

+   from modules.tool_use import get_available_tools
+   shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=shared.settings.get('selected_tools', []), label='Tools', info='Functions the model can call during generation.', elem_id='tools-group')
+   shared.gradio['tools_refresh'] = gr.Button('Refresh list', elem_id='tools-refresh-btn', visible=False)
+   shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']])
+
+   def sync_web_tools(selected):
+       if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools():
+           selected.append('fetch_webpage')
+
+       return gr.update(value=selected)
+
+   shared.gradio['selected_tools'].change(fn=sync_web_tools, inputs=[shared.gradio['selected_tools']], outputs=[shared.gradio['selected_tools']], show_progress=False)
+
+   gr.HTML("<div class='sidebar-vertical-separator'></div>")
+
    with gr.Row():
        shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')

@@ -275,6 +291,10 @@ def create_event_handlers():
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)

+   shared.gradio['Start incognito chat'].click(
+       ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+       chat.handle_start_incognito_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
+
    shared.gradio['delete_chat-confirm'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)

@@ -728,6 +728,8 @@ def generate_prompt_variation(state):
        variation = variation.rsplit("</think>", 1)[1]
    elif "<|start|>assistant<|channel|>final<|message|>" in variation:
        variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
+   elif "<|channel|>final<|message|>" in variation:
+       variation = variation.rsplit("<|channel|>final<|message|>", 1)[1]
    elif "</seed:think>" in variation:
        variation = variation.rsplit("</seed:think>", 1)[1]

@@ -42,11 +42,11 @@ def create_ui():
    with gr.Row():
        with gr.Column():
            shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=-1, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Number of layers to offload to the GPU. -1 = auto.')
-           shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. llama.cpp: 0 = auto if gpu-layers is also -1. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
+           shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=1048576, step=1024, value=shared.args.ctx_size, info='Context length. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
            shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
            shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
            shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
-           shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices. Default: 1024.')
+           shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices.')
            shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')

        with gr.Column():

@@ -100,9 +100,6 @@ def create_ui():
            shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
            shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
            shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
-           shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
-           shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
-           shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
            shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
            shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')

@@ -388,7 +385,11 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
def update_truncation_length(current_length, state):
    if 'loader' in state:
        if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
-           return state['ctx_size']
+           if state['ctx_size'] > 0:
+               return state['ctx_size']
+
+           # ctx_size == 0 means auto: use the actual value from the server
+           return shared.settings['truncation_length']

    return current_length

@@ -37,10 +37,10 @@ def create_ui():
            shared.gradio['dynamic_temperature'] = gr.Checkbox(value=shared.settings['dynamic_temperature'], label='dynamic_temperature')

            gr.Markdown('## Curve cutoff')
-           shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=shared.settings['min_p'], step=0.01, label='min_p')
-           shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=shared.settings['top_n_sigma'], step=0.01, label='top_n_sigma')
            shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=shared.settings['top_p'], step=0.01, label='top_p')
            shared.gradio['top_k'] = gr.Slider(0, 200, value=shared.settings['top_k'], step=1, label='top_k')
+           shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=shared.settings['min_p'], step=0.01, label='min_p')
+           shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=shared.settings['top_n_sigma'], step=0.01, label='top_n_sigma')
            shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=shared.settings['typical_p'], step=0.01, label='typical_p')
            shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=shared.settings['xtc_threshold'], step=0.01, label='xtc_threshold', info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.')
            shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=shared.settings['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.')

@@ -73,7 +73,7 @@ def create_ui():
            gr.Markdown('## Other options')
            shared.gradio['do_sample'] = gr.Checkbox(value=shared.settings['do_sample'], label='do_sample')
            shared.gradio['temperature_last'] = gr.Checkbox(value=shared.settings['temperature_last'], label='temperature_last', info='Moves temperature/dynamic temperature/quadratic sampling to the end of the sampler stack, ignoring their positions in "Sampler priority".')
-           shared.gradio['sampler_priority'] = gr.Textbox(value=shared.settings['sampler_priority'], lines=10, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
+           shared.gradio['sampler_priority'] = gr.DragDrop(value=shared.settings['sampler_priority'], label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
            shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=shared.settings['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')

        with gr.Column():

@@ -1,11 +1,12 @@
import concurrent.futures
import html
+import ipaddress
import random
import re
-import urllib.request
+import socket
from concurrent.futures import as_completed
from datetime import datetime
-from urllib.parse import quote_plus
+from urllib.parse import parse_qs, quote_plus, urljoin, urlparse

import requests

@ -13,34 +14,60 @@ from modules import shared
from modules.logging_colors import logger


def _validate_url(url):
"""Validate that a URL is safe to fetch (not targeting private/internal networks)."""
parsed = urlparse(url)
if parsed.scheme not in ('http', 'https'):
raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")

hostname = parsed.hostname
if not hostname:
raise ValueError("No hostname in URL")

# Resolve hostname and check all returned addresses
try:
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
ip = ipaddress.ip_address(sockaddr[0])
if not ip.is_global:
raise ValueError(f"Access to non-public address {ip} is blocked")
except socket.gaierror:
raise ValueError(f"Could not resolve hostname: {hostname}")


def get_current_timestamp():
"""Returns the current time in 24-hour format"""
return datetime.now().strftime('%b %d, %Y %H:%M')


def download_web_page(url, timeout=10):
def download_web_page(url, timeout=10, include_links=False):
"""
Download a web page and convert its HTML content to structured Markdown text.
Download a web page and extract its main content as Markdown text.
"""
import html2text
import trafilatura

try:
_validate_url(url)
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}
response = requests.get(url, headers=headers, timeout=timeout)
response.raise_for_status()  # Raise an exception for bad status codes
max_redirects = 5
for _ in range(max_redirects):
response = requests.get(url, headers=headers, timeout=timeout, allow_redirects=False)
if response.is_redirect and 'Location' in response.headers:
url = urljoin(url, response.headers['Location'])
_validate_url(url)
else:
break

# Initialize the HTML to Markdown converter
h = html2text.HTML2Text()
h.body_width = 0
h.ignore_images = True
h.ignore_links = True
response.raise_for_status()

# Convert the HTML to Markdown
markdown_text = h.handle(response.text)

return markdown_text
result = trafilatura.extract(
response.text,
include_links=include_links,
output_format='markdown',
url=url
)
return result or ""
except requests.exceptions.RequestException as e:
logger.error(f"Error downloading {url}: {e}")
return ""
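The resolve-then-classify check that `_validate_url` performs above can be exercised standalone. A minimal sketch of the same idea follows; the function name `is_url_public` is mine, not from the codebase, and it returns a boolean instead of raising:

```python
import ipaddress
import socket
from urllib.parse import urlparse

def is_url_public(url):
    """Return True only if the scheme is http(s) and every resolved address is globally routable."""
    parsed = urlparse(url)
    if parsed.scheme not in ('http', 'https') or not parsed.hostname:
        return False
    try:
        infos = socket.getaddrinfo(parsed.hostname, None)
    except socket.gaierror:
        return False
    # info[4][0] is the resolved address string for each (family, ..., sockaddr) tuple
    return all(ipaddress.ip_address(info[4][0]).is_global for info in infos)

print(is_url_public("http://127.0.0.1/admin"))  # loopback -> False
print(is_url_public("ftp://example.com/file"))  # wrong scheme -> False
```

Checking every address returned by `getaddrinfo` matters: a hostname can resolve to both a public and a private address, and only rejecting the first result would leave the SSRF hole open.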
@ -49,8 +76,8 @@ def download_web_page(url, timeout=10):
return ""


def perform_web_search(query, num_pages=3, max_workers=5, timeout=10):
"""Perform web search and return results with content"""
def perform_web_search(query, num_pages=3, max_workers=5, timeout=10, fetch_content=True):
"""Perform web search and return results, optionally with page content"""
try:
search_url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}"
@ -59,25 +86,41 @@ def perform_web_search(query, num_pages=3, max_workers=5, timeout=10):
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
]

response_text = ""
req = urllib.request.Request(search_url, headers={'User-Agent': random.choice(agents)})
with urllib.request.urlopen(req, timeout=timeout) as response:
response_text = response.read().decode('utf-8')
response = requests.get(search_url, headers={'User-Agent': random.choice(agents)}, timeout=timeout)
response.raise_for_status()
response_text = response.text

# Extract results with regex
titles = re.findall(r'<a[^>]*class="[^"]*result__a[^"]*"[^>]*>(.*?)</a>', response_text, re.DOTALL)
urls = re.findall(r'<a[^>]*class="[^"]*result__url[^"]*"[^>]*>(.*?)</a>', response_text, re.DOTALL)
# Extract results - title and URL come from the same <a class="result__a"> element
result_links = re.findall(r'<a[^>]*class="[^"]*result__a[^"]*"[^>]*>(.*?)</a>', response_text, re.DOTALL)
result_tags = re.findall(r'<a([^>]*class="[^"]*result__a[^"]*"[^>]*)>', response_text, re.DOTALL)

# Prepare download tasks
download_tasks = []
for i in range(min(len(titles), len(urls), num_pages)):
url = f"https://{urls[i].strip()}"
title = re.sub(r'<[^>]+>', '', titles[i]).strip()
title = html.unescape(title)
download_tasks.append((url, title, i))
for i, (tag_attrs, raw_title) in enumerate(zip(result_tags, result_links)):
if num_pages is not None and i >= num_pages:
break
# Extract href and resolve the actual URL from DuckDuckGo's redirect link
href_match = re.search(r'href="([^"]*)"', tag_attrs)
if not href_match:
continue
uddg = parse_qs(urlparse(html.unescape(href_match.group(1))).query).get('uddg', [''])[0]
if not uddg:
continue
title = html.unescape(re.sub(r'<[^>]+>', '', raw_title).strip())
download_tasks.append((uddg, title, len(download_tasks)))

search_results = [None] * len(download_tasks)  # Pre-allocate to maintain order

if not fetch_content:
for url, title, index in download_tasks:
search_results[index] = {
'title': title,
'url': url,
'content': ''
}

return search_results

# Download pages in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all download tasks
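DuckDuckGo's HTML results wrap each target in a `/l/?uddg=...` redirect link, which is why the new code pulls the `uddg` query parameter instead of using the anchor text. A small sketch of that step (the helper name and sample href are mine for illustration):

```python
import html
from urllib.parse import parse_qs, urlparse

def resolve_ddg_redirect(href):
    """Pull the real target URL out of a DuckDuckGo /l/?uddg=... redirect link."""
    # The href comes from raw HTML, so entities like &amp; must be unescaped first
    query = urlparse(html.unescape(href)).query
    return parse_qs(query).get('uddg', [''])[0]

href = "//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage&amp;rut=abc"
print(resolve_ddg_redirect(href))  # https://example.com/page
```

`parse_qs` percent-decodes the value, so the returned string is the final destination URL, ready to hand to the downloader.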
42
one_click.py
@ -91,7 +91,7 @@ def get_gpu_choice():
"What is your GPU?",
{
'A': 'NVIDIA',
'B': 'AMD - Linux/macOS only, requires ROCm 6.4',
'B': 'AMD - Linux only, ROCm 7.2',
'C': 'Apple M Series',
'D': 'Intel Arc (beta)',
'N': 'CPU mode'
@ -116,14 +116,12 @@ def get_pytorch_install_command(gpu_choice):
if gpu_choice == "NVIDIA_CUDA128":
return base_cmd + "--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback
elif gpu_choice == "AMD":
return base_cmd + "--index-url https://download.pytorch.org/whl/rocm6.4" + pypi_fallback
py_tag = f"cp{PYTHON_VERSION.replace('.', '')}"
return f"python -m pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl"
elif gpu_choice in ["APPLE", "NONE"]:
return base_cmd + "--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback
elif gpu_choice == "INTEL":
if is_linux():
return "python -m pip install torch==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
else:
return "python -m pip install torch==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
return base_cmd + "--index-url https://download.pytorch.org/whl/xpu"
else:
return base_cmd
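The AMD branch now builds a direct wheel URL on repo.radeon.com rather than using a PyTorch index. The `py_tag` construction can be sketched as follows; the `PYTHON_VERSION` and `TORCH_VERSION` values here are hypothetical stand-ins for the constants one_click.py defines elsewhere:

```python
# Hypothetical values for illustration; one_click.py supplies its own constants.
PYTHON_VERSION = "3.13"
TORCH_VERSION = "2.9.0"

# cp313 is the CPython ABI tag pip expects in wheel filenames
py_tag = f"cp{PYTHON_VERSION.replace('.', '')}"
wheel = (
    f"https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/"
    f"torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl"
)
print(py_tag)  # cp313
```

The `%2B` in the URL is a percent-encoded `+`, since the local version segment `+rocm7.2.0...` cannot appear literally in a URL path.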
@ -136,12 +134,12 @@ def get_pytorch_update_command(gpu_choice):
if gpu_choice == "NVIDIA_CUDA128":
return f"{base_cmd}--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback
elif gpu_choice == "AMD":
return f"{base_cmd}--index-url https://download.pytorch.org/whl/rocm6.4" + pypi_fallback
py_tag = f"cp{PYTHON_VERSION.replace('.', '')}"
return f"python -m pip install --upgrade https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl"
elif gpu_choice in ["APPLE", "NONE"]:
return f"{base_cmd}--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback
elif gpu_choice == "INTEL":
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
return f"{base_cmd}{intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
return f"{base_cmd}--index-url https://download.pytorch.org/whl/xpu"
else:
return base_cmd
@ -196,6 +194,8 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False,
if environment:
if is_windows():
conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat")
python_path = os.path.join(conda_env_path, "python.exe")
cmd = cmd.replace("python ", f'"{python_path}" ')
cmd = f'"{conda_bat_path}" activate "{conda_env_path}" >nul && {cmd}'
else:
conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh")
@ -270,7 +270,7 @@ def update_pytorch_and_python():


def clean_outdated_pytorch_cuda_dependencies():
patterns = ["cu121", "cu122", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"]
patterns = ["cu121", "cu122", "rocm6", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"]
result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True)
matching_packages = []
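The pattern list above (now including `rocm6` so old ROCm 6.x builds are cleaned up) is matched against `pip list --format=freeze` output. A sketch of that matching logic, assuming a simple substring scan rather than one_click.py's exact implementation:

```python
patterns = ["cu121", "cu122", "rocm6", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"]

def outdated_packages(freeze_output):
    """Return names of packages whose pip-freeze line mentions any outdated pattern."""
    matches = []
    for line in freeze_output.splitlines():
        if any(p in line for p in patterns):
            # freeze lines look like "name==version+local"
            matches.append(line.split('==')[0])
    return matches

freeze = "torch==2.4.1+cu121\nnumpy==2.1.0\ntorchvision==0.19.1+cu121"
print(outdated_packages(freeze))  # ['torch', 'torchvision']
```

Matching on the whole freeze line catches stale builds both by package name (`torchvision`) and by local version tag (`+cu121`, `+rocm6.x`).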
@ -316,13 +316,6 @@ def install_webui():
install_pytorch = get_pytorch_install_command(gpu_choice)
run_cmd(f"conda install -y ninja git && {install_pytorch}", assert_success=True, environment=True)

if gpu_choice == "INTEL":
# Install oneAPI dependencies via conda
print_big_message("Installing Intel oneAPI runtime libraries.")
run_cmd("conda install -y -c https://software.repos.intel.com/python/conda/ -c conda-forge dpcpp-cpp-rt=2024.0 mkl-dpcpp=2024.0", environment=True)
# Install libuv required by Intel-patched torch
run_cmd("conda install -y libuv", environment=True)

# Install the webui requirements
update_requirements(initial_installation=True, pull=False)
@ -365,8 +358,10 @@ def update_requirements(initial_installation=False, pull=True):

current_commit = get_current_commit()
wheels_changed = not os.path.exists(state_file)
installed_wheels = set()
if not wheels_changed:
state = load_state()
installed_wheels = set(state.get('installed_wheels', []))
if 'wheels_changed' in state or state.get('last_installed_commit') != current_commit:
wheels_changed = True
@ -431,9 +426,17 @@

# Prepare the requirements file
textgen_requirements = open(requirements_file).read().splitlines()
all_whl_lines = [line.strip() for line in textgen_requirements if '.whl' in line]

if not initial_installation and not wheels_changed:
textgen_requirements = [line for line in textgen_requirements if '.whl' not in line]
if not initial_installation:
if installed_wheels:
# Per-wheel comparison: only re-download wheels that changed
textgen_requirements = [
line for line in textgen_requirements
if '.whl' not in line or line.strip() not in installed_wheels
]
elif not wheels_changed:
textgen_requirements = [line for line in textgen_requirements if '.whl' not in line]

with open('temp_requirements.txt', 'w') as file:
file.write('\n'.join(textgen_requirements))
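The per-wheel comparison above keeps every non-wheel requirement and drops only the wheel URLs that are already recorded in the saved state, so an update re-downloads just the wheels whose URLs changed. The filter can be sketched in isolation (the sample URLs are hypothetical):

```python
def filter_requirements(lines, installed_wheels):
    """Keep non-wheel lines; keep a wheel line only if it is not already installed."""
    return [l for l in lines if '.whl' not in l or l.strip() not in installed_wheels]

reqs = [
    "jinja2==3.1.6",
    "https://example.com/pkg-1.0-py3-none-any.whl",
    "https://example.com/pkg-2.0-py3-none-any.whl",
]
installed = {"https://example.com/pkg-1.0-py3-none-any.whl"}
print(filter_requirements(reqs, installed))
```

Comparing the stripped line against the stored set means bumping a single wheel version in requirements (say 1.0 to 2.0) triggers exactly one re-download instead of reinstalling every wheel.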
@ -452,6 +455,7 @@
# Save state after successful installation
state = load_state()
state['last_installed_commit'] = current_commit
state['installed_wheels'] = all_whl_lines
state.pop('wheels_changed', None)
save_state(state)
@ -2,11 +2,10 @@ accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
bitsandbytes==0.49.*
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
flash-linear-attention==0.4.*
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -25,14 +24,15 @@ scipy
sentencepiece
tensorboard
torchao==0.15.*
trafilatura==2.0.0
transformers==5.3.*
triton-windows==3.5.1.post24; platform_system == "Windows"
tqdm
wandb

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -40,9 +40,9 @@ sse-starlette==1.6.5
tiktoken

# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.23/exllamav3-0.0.23+cu128.torch2.9.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.23/exllamav3-0.0.23+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.25/exllamav3-0.0.25+cu128.torch2.9.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.25/exllamav3-0.0.25+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -37,5 +37,5 @@ sse-starlette==1.6.5
tiktoken

# AMD wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -37,4 +37,4 @@ sse-starlette==1.6.5
tiktoken

# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -37,4 +37,4 @@ sse-starlette==1.6.5
tiktoken

# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -37,5 +37,5 @@ sse-starlette==1.6.5
tiktoken

# llama.cpp (CPU only)
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
@ -1,10 +1,9 @@
accelerate==1.12.*
audioop-lts<1.0; python_version >= "3.13"
datasets
diffusers==0.36.*
diffusers==0.37.*
einops
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -25,11 +24,12 @@ tensorboard
torchao==0.15.*
transformers==5.3.*
tqdm
trafilatura==2.0.0
wandb

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15
@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken

# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken

# AMD wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -23,4 +23,4 @@ sse-starlette==1.6.5
tiktoken

# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -23,4 +23,4 @@ sse-starlette==1.6.5
tiktoken

# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken

# llama.cpp (CPU only)
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken

# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu131-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu131-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu131-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu131-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown

@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15
@ -1,6 +1,5 @@
audioop-lts<1.0; python_version >= "3.13"
fastapi==0.112.4
html2text==2025.4.15
huggingface-hub==1.5.*
jinja2==3.1.6
markdown
@ -11,11 +10,12 @@ python-docx==1.1.2
pyyaml
requests
rich
trafilatura==2.0.0
tqdm

# Gradio
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl

# API
flask_cloudflared==0.0.15
@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken

# Vulkan wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

server.py
@ -1,58 +1,20 @@
import os
import shutil
import warnings
from pathlib import Path

from modules import shared, ui  # ui must be imported early to avoid circular imports
from modules.image_models import load_image_model
from modules.logging_colors import logger
from modules.prompts import load_prompt

# Set up Gradio temp directory path
gradio_temp_path = shared.user_data_dir / 'cache' / 'gradio'
shutil.rmtree(gradio_temp_path, ignore_errors=True)
gradio_temp_path.mkdir(parents=True, exist_ok=True)

# Set environment variables
os.environ.update({
    'GRADIO_ANALYTICS_ENABLED': 'False',
    'BITSANDBYTES_NOWELCOME': '1',
    'GRADIO_TEMP_DIR': str(gradio_temp_path)
})

warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')

import gradio as gr

import os
import signal
import sys
import time
import warnings
from functools import partial
from pathlib import Path
from threading import Lock, Thread

import yaml

from modules import shared, utils
from modules.image_models import load_image_model
from modules.logging_colors import logger
from modules.prompts import load_prompt

import modules.extensions as extensions_module
from modules import (
    training,
    ui,
    ui_chat,
    ui_default,
    ui_file_saving,
    ui_image_generation,
    ui_model_menu,
    ui_notebook,
    ui_parameters,
    ui_session,
    utils
)
from modules.chat import generate_pfp_cache
from modules.extensions import apply_extensions
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model_if_idle
from modules.models_settings import (
@ -61,10 +23,20 @@ from modules.models_settings import (
    update_model_parameters
)
from modules.shared import do_cmd_flags_warnings
from modules.utils import gradio

os.environ['BITSANDBYTES_NOWELCOME'] = '1'

warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')


def signal_handler(sig, frame):
    # On second Ctrl+C, force an immediate exit
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    signal.signal(signal.SIGTERM, signal.SIG_DFL)

    logger.info("Received Ctrl+C. Shutting down Text Generation Web UI gracefully.")

    # Explicitly stop LlamaServer to avoid __del__ cleanup issues during shutdown
@ -83,6 +55,37 @@ signal.signal(signal.SIGTERM, signal_handler)

def create_interface():

    import shutil

    import gradio as gr

    from modules import (
        training,
        ui,
        ui_chat,
        ui_default,
        ui_file_saving,
        ui_image_generation,
        ui_model_menu,
        ui_notebook,
        ui_parameters,
        ui_session,
    )
    from modules.chat import generate_pfp_cache
    from modules.extensions import apply_extensions
    from modules.utils import gradio

    warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')

    # Set up Gradio temp directory path
    gradio_temp_path = shared.user_data_dir / 'cache' / 'gradio'
    shutil.rmtree(gradio_temp_path, ignore_errors=True)
    gradio_temp_path.mkdir(parents=True, exist_ok=True)
    os.environ.update({
        'GRADIO_ANALYTICS_ENABLED': 'False',
        'GRADIO_TEMP_DIR': str(gradio_temp_path)
    })

    title = 'Text Generation Web UI'

    # Password authentication
@ -215,6 +218,10 @@ def create_interface():

    shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False)

    # Sync theme_state with the actual client-side theme so that
    # autosave always writes the correct dark_theme value.
    shared.gradio['interface'].load(None, None, gradio('theme_state'), js='() => document.body.classList.contains("dark") ? "dark" : "light"')

    extensions_module.create_extensions_tabs()  # Extensions tabs
    extensions_module.create_extensions_block()  # Extensions block
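The signal_handler added above relies on a common double-Ctrl+C idiom: the handler restores SIG_DFL first, so a second Ctrl+C during a slow shutdown terminates the process immediately instead of re-entering the handler. A minimal standalone sketch of that idiom (the helper names here are illustrative, not taken from server.py):

```python
import signal


def make_shutdown_handler(cleanup):
    """Build a handler that runs cleanup() once; because SIG_DFL is
    restored first, a repeat Ctrl+C during cleanup takes the default
    (immediate-exit) path instead of re-entering this handler."""
    def handler(sig, frame):
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        cleanup()          # graceful shutdown work goes here
        raise SystemExit(0)
    return handler


events = []
signal.signal(signal.SIGINT, make_shutdown_handler(lambda: events.append("cleaned up")))
```

The ordering is the point of the pattern: restoring the default disposition before doing any slow cleanup guarantees the user always has an escape hatch.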
@ -5,6 +5,7 @@ setlocal enabledelayedexpansion
set PYTHONNOUSERSITE=1
set PYTHONPATH=
set PYTHONHOME=
set PYTHONUTF8=1

cd /D "%~dp0"
@ -1 +0,0 @@
min_p: 0.2

@ -1,3 +0,0 @@
temperature: 0.7
top_p: 0.8
top_k: 20

@ -1,3 +0,0 @@
temperature: 0.6
top_p: 0.95
top_k: 20

user_data/presets/Top-P.yaml Normal file

@ -0,0 +1 @@
top_p: 0.95

@ -1 +0,0 @@
min_p: 0.05

user_data/tools/calculate.py Normal file
@ -0,0 +1,52 @@
import ast
import operator

OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.Mod: operator.mod,
    ast.USub: operator.neg,
}


def _eval(node):
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    elif isinstance(node, ast.BinOp) and type(node.op) in OPERATORS:
        left = _eval(node.left)
        right = _eval(node.right)
        if isinstance(node.op, ast.Pow) and isinstance(right, (int, float)) and abs(right) > 10000:
            raise ValueError("Exponent too large (max 10000)")
        return OPERATORS[type(node.op)](left, right)
    elif isinstance(node, ast.UnaryOp) and type(node.op) in OPERATORS:
        return OPERATORS[type(node.op)](_eval(node.operand))
    raise ValueError("Unsupported expression")


tool = {
    "type": "function",
    "function": {
        "name": "calculate",
        "description": "Evaluate a math expression. Supports +, -, *, /, **, %.",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {"type": "string", "description": "The math expression to evaluate (e.g. '2 * (3 + 4)')."},
            },
            "required": ["expression"]
        }
    }
}


def execute(arguments):
    expr = arguments.get("expression", "")
    try:
        tree = ast.parse(expr, mode='eval')
        result = _eval(tree.body)
        return {"expression": expr, "result": result}
    except Exception as e:
        return {"error": str(e)}
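calculate.py avoids eval() by walking the parsed AST and only dispatching node types found in its OPERATORS whitelist, so names, calls, and attribute access raise instead of executing. A trimmed, self-contained replica of that pattern (reimplemented here so it runs without the rest of the tool, with a reduced operator set):

```python
import ast
import operator

# Whitelist of binary operators; anything outside it is rejected.
OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def safe_eval(expr):
    """Evaluate +, -, *, / over numeric literals; reject everything else."""
    def walk(node):
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in OPS:
            return OPS[type(node.op)](walk(node.left), walk(node.right))
        raise ValueError("Unsupported expression")
    return walk(ast.parse(expr, mode="eval").body)


print(safe_eval("2 * (3 + 4)"))  # → 14
```

Because an input like `__import__('os')` parses to a Call node, which is not in the whitelist, it raises ValueError rather than running, which is what makes this safer than eval().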

user_data/tools/fetch_webpage.py Normal file
@ -0,0 +1,30 @@
from modules.web_search import download_web_page, truncate_content_by_tokens

tool = {
    "type": "function",
    "function": {
        "name": "fetch_webpage",
        "description": "Fetch and read the contents of a web page given its URL. Returns the page content as plain text.",
        "parameters": {
            "type": "object",
            "properties": {
                "url": {"type": "string", "description": "The URL of the web page to fetch."},
                "max_tokens": {"type": "integer", "description": "Maximum number of tokens in the returned content (default: 2048)."},
            },
            "required": ["url"]
        }
    }
}


def execute(arguments):
    url = arguments.get("url", "")
    max_tokens = arguments.get("max_tokens", 2048)
    if not url:
        return {"error": "No URL provided."}

    content = download_web_page(url, include_links=True)
    if not content or not content.strip():
        return {"error": f"Failed to fetch content from {url}"}

    return {"url": url, "content": truncate_content_by_tokens(content, max_tokens=max_tokens)}

user_data/tools/get_datetime.py Normal file
@ -0,0 +1,18 @@
from datetime import datetime

tool = {
    "type": "function",
    "function": {
        "name": "get_datetime",
        "description": "Get the current date and time.",
        "parameters": {
            "type": "object",
            "properties": {},
        }
    }
}


def execute(arguments):
    now = datetime.now()
    return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")}

user_data/tools/roll_dice.py Normal file
@ -0,0 +1,23 @@
import random

tool = {
    "type": "function",
    "function": {
        "name": "roll_dice",
        "description": "Roll one or more dice with the specified number of sides.",
        "parameters": {
            "type": "object",
            "properties": {
                "count": {"type": "integer", "description": "Number of dice to roll.", "default": 1},
                "sides": {"type": "integer", "description": "Number of sides per die.", "default": 20},
            },
        }
    }
}


def execute(arguments):
    count = max(1, min(arguments.get("count", 1), 1000))
    sides = max(2, min(arguments.get("sides", 20), 1000))
    rolls = [random.randint(1, sides) for _ in range(count)]
    return {"rolls": rolls, "total": sum(rolls)}
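The execute() above sanitizes model-supplied arguments with the max(lo, min(value, hi)) clamp idiom before using them, so a malformed tool call cannot request a million dice or a one-sided die. The idiom in isolation, as a small illustrative sketch:

```python
def clamp(value, lo, hi):
    """Pin value into the inclusive range [lo, hi]."""
    return max(lo, min(value, hi))


print(clamp(5000, 1, 1000))  # → 1000: an oversized request is capped
print(clamp(0, 2, 1000))     # → 2: a die needs at least two sides
print(clamp(6, 2, 1000))     # → 6: in-range values pass through
```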

user_data/tools/web_search.py Normal file
@ -0,0 +1,27 @@
from modules.web_search import perform_web_search

tool = {
    "type": "function",
    "function": {
        "name": "web_search",
        "description": "Search the web using DuckDuckGo and return a list of result titles and URLs. Use fetch_webpage to read the contents of a specific result.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "The search query."},
            },
            "required": ["query"]
        }
    }
}


def execute(arguments):
    query = arguments.get("query", "")
    results = perform_web_search(query, num_pages=None, fetch_content=False)
    output = []
    for r in results:
        if r:
            output.append({"title": r["title"], "url": r["url"]})

    return output if output else [{"error": "No results found."}]
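Each tool module above follows the same convention: a module-level `tool` dict holding an OpenAI-style function schema, plus an `execute(arguments)` entry point. A hypothetical dispatcher over such modules (not part of this commit; the tool module here is a stand-in built inline so the sketch is self-contained) could look like:

```python
import types

# Stand-in for one tool module; real ones would be imported
# from user_data/tools/.
echo = types.SimpleNamespace(
    tool={"type": "function", "function": {"name": "echo", "parameters": {}}},
    execute=lambda args: {"echoed": args.get("text", "")},
)

# Index modules by the name declared in their own schema.
REGISTRY = {m.tool["function"]["name"]: m for m in [echo]}


def dispatch(name, arguments):
    """Route a model's tool call to the matching module's execute()."""
    module = REGISTRY.get(name)
    if module is None:
        return {"error": f"Unknown tool: {name}"}
    return module.execute(arguments)


print(dispatch("echo", {"text": "hi"}))  # → {'echoed': 'hi'}
```

Keying the registry on the schema's declared name (rather than the file name) keeps the contract in one place: whatever name the model sees in the schema is the name the dispatcher resolves.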