mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-08 16:13:41 +00:00
Compare commits
No commits in common. "main" and "v4.0" have entirely different histories.
169 changed files with 4376 additions and 6443 deletions
37
.github/workflows/build-everything-tgw.yml
vendored
37
.github/workflows/build-everything-tgw.yml
vendored
|
|
@ -41,13 +41,6 @@ jobs:
|
||||||
version: ${{ inputs.version }}
|
version: ${{ inputs.version }}
|
||||||
config: 'os:ubuntu-22.04'
|
config: 'os:ubuntu-22.04'
|
||||||
|
|
||||||
build_release_rocm_windows:
|
|
||||||
name: ROCm Windows
|
|
||||||
uses: ./.github/workflows/build-portable-release-rocm.yml
|
|
||||||
with:
|
|
||||||
version: ${{ inputs.version }}
|
|
||||||
config: 'os:windows-2022'
|
|
||||||
|
|
||||||
build_release_rocm_linux:
|
build_release_rocm_linux:
|
||||||
name: ROCm Linux
|
name: ROCm Linux
|
||||||
uses: ./.github/workflows/build-portable-release-rocm.yml
|
uses: ./.github/workflows/build-portable-release-rocm.yml
|
||||||
|
|
@ -74,32 +67,4 @@ jobs:
|
||||||
uses: ./.github/workflows/build-portable-release.yml
|
uses: ./.github/workflows/build-portable-release.yml
|
||||||
with:
|
with:
|
||||||
version: ${{ inputs.version }}
|
version: ${{ inputs.version }}
|
||||||
config: 'os:macos-15-intel,macos-14'
|
config: 'os:macos-13,macos-14'
|
||||||
|
|
||||||
build_release_ik_cuda_windows:
|
|
||||||
name: ik CUDA Windows
|
|
||||||
uses: ./.github/workflows/build-portable-release-ik-cuda.yml
|
|
||||||
with:
|
|
||||||
version: ${{ inputs.version }}
|
|
||||||
config: 'os:windows-2022'
|
|
||||||
|
|
||||||
build_release_ik_cuda_linux:
|
|
||||||
name: ik CUDA Linux
|
|
||||||
uses: ./.github/workflows/build-portable-release-ik-cuda.yml
|
|
||||||
with:
|
|
||||||
version: ${{ inputs.version }}
|
|
||||||
config: 'os:ubuntu-22.04'
|
|
||||||
|
|
||||||
build_release_ik_cpu_windows:
|
|
||||||
name: ik CPU Windows
|
|
||||||
uses: ./.github/workflows/build-portable-release-ik.yml
|
|
||||||
with:
|
|
||||||
version: ${{ inputs.version }}
|
|
||||||
config: 'os:windows-2022'
|
|
||||||
|
|
||||||
build_release_ik_cpu_linux:
|
|
||||||
name: ik CPU Linux
|
|
||||||
uses: ./.github/workflows/build-portable-release-ik.yml
|
|
||||||
with:
|
|
||||||
version: ${{ inputs.version }}
|
|
||||||
config: 'os:ubuntu-22.04'
|
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ jobs:
|
||||||
cd "text-generation-webui-${VERSION_CLEAN}"
|
cd "text-generation-webui-${VERSION_CLEAN}"
|
||||||
|
|
||||||
# Remove extensions that need additional requirements
|
# Remove extensions that need additional requirements
|
||||||
allowed=("character_bias" "gallery" "sd_api_pictures")
|
allowed=("character_bias" "gallery" "openai" "sd_api_pictures")
|
||||||
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
||||||
|
|
||||||
# Define common variables
|
# Define common variables
|
||||||
|
|
@ -116,13 +116,13 @@ jobs:
|
||||||
# 1. Set platform-specific variables
|
# 1. Set platform-specific variables
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
PLATFORM="windows"
|
PLATFORM="windows"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/python.exe -m pip"
|
PIP_PATH="portable_env/python.exe -m pip"
|
||||||
PACKAGES_PATH="portable_env/Lib/site-packages"
|
PACKAGES_PATH="portable_env/Lib/site-packages"
|
||||||
rm start_linux.sh start_macos.sh
|
rm start_linux.sh start_macos.sh
|
||||||
else
|
else
|
||||||
PLATFORM="linux"
|
PLATFORM="linux"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
||||||
rm start_macos.sh start_windows.bat
|
rm start_macos.sh start_windows.bat
|
||||||
|
|
@ -150,16 +150,15 @@ jobs:
|
||||||
# 5. Clean up
|
# 5. Clean up
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
||||||
|
|
||||||
# 6. Create archive
|
# 6. Create ZIP file
|
||||||
cd ..
|
cd ..
|
||||||
|
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.zip"
|
||||||
|
echo "Creating archive: $ZIP_NAME"
|
||||||
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.zip"
|
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
|
||||||
else
|
else
|
||||||
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.tar.gz"
|
zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
- name: Upload files to a GitHub release
|
||||||
|
|
@ -168,7 +167,7 @@ jobs:
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
file: ../textgen-portable-*
|
file: ../textgen-portable-*.zip
|
||||||
tag: ${{ inputs.version }}
|
tag: ${{ inputs.version }}
|
||||||
file_glob: true
|
file_glob: true
|
||||||
make_latest: false
|
make_latest: false
|
||||||
|
|
|
||||||
178
.github/workflows/build-portable-release-ik-cuda.yml
vendored
178
.github/workflows/build-portable-release-ik-cuda.yml
vendored
|
|
@ -1,178 +0,0 @@
|
||||||
name: Build ik CUDA
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
version:
|
|
||||||
description: 'Version tag of text-generation-webui to build: v3.0'
|
|
||||||
default: 'v3.0'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
config:
|
|
||||||
description: 'Override configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2'
|
|
||||||
default: 'Default'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
exclude:
|
|
||||||
description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2'
|
|
||||||
default: 'None'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
workflow_call:
|
|
||||||
inputs:
|
|
||||||
version:
|
|
||||||
description: 'Version tag of text-generation-webui to build: v3.0'
|
|
||||||
default: 'v3.0'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
config:
|
|
||||||
description: 'Configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2'
|
|
||||||
default: 'Default'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
exclude:
|
|
||||||
description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2'
|
|
||||||
default: 'None'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
define_matrix:
|
|
||||||
name: Define Build Matrix
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
outputs:
|
|
||||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: pwsh
|
|
||||||
env:
|
|
||||||
CONFIGIN: ${{ inputs.config }}
|
|
||||||
EXCLUDEIN: ${{ inputs.exclude }}
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Define Job Output
|
|
||||||
id: set-matrix
|
|
||||||
run: |
|
|
||||||
$matrix = @{
|
|
||||||
'os' = @('ubuntu-22.04', 'windows-2022')
|
|
||||||
'pyver' = @("3.13")
|
|
||||||
'cuda' = @("12.4", "13.1")
|
|
||||||
}
|
|
||||||
|
|
||||||
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
|
||||||
|
|
||||||
if ($env:EXCLUDEIN -ne 'None') {
|
|
||||||
$exclusions = @()
|
|
||||||
$exclusions += $env:EXCLUDEIN.split(';').replace(':','=').replace(',',"`n") | ConvertFrom-StringData
|
|
||||||
$matrix['exclude'] = $exclusions
|
|
||||||
}
|
|
||||||
|
|
||||||
$matrixOut = ConvertTo-Json $matrix -Compress
|
|
||||||
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
|
||||||
|
|
||||||
build_wheels:
|
|
||||||
name: ${{ matrix.os }} ${{ matrix.pyver }} CUDA ${{ matrix.cuda }}
|
|
||||||
needs: define_matrix
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
strategy:
|
|
||||||
matrix: ${{ fromJSON(needs.define_matrix.outputs.matrix) }}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: pwsh
|
|
||||||
env:
|
|
||||||
PCKGVER: ${{ inputs.version }}
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v6
|
|
||||||
with:
|
|
||||||
repository: 'oobabooga/text-generation-webui'
|
|
||||||
ref: ${{ inputs.version }}
|
|
||||||
submodules: 'recursive'
|
|
||||||
|
|
||||||
- uses: actions/setup-python@v6
|
|
||||||
with:
|
|
||||||
python-version: ${{ matrix.pyver }}
|
|
||||||
|
|
||||||
- name: Build Package
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
VERSION_CLEAN="${{ inputs.version }}"
|
|
||||||
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
|
||||||
cd ..
|
|
||||||
cp -r text-generation-webui "text-generation-webui-ik-${VERSION_CLEAN}"
|
|
||||||
cd "text-generation-webui-ik-${VERSION_CLEAN}"
|
|
||||||
|
|
||||||
# Remove extensions that need additional requirements
|
|
||||||
allowed=("character_bias" "gallery" "sd_api_pictures")
|
|
||||||
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
|
||||||
|
|
||||||
# Define common variables
|
|
||||||
CUDA_VERSION="${{ matrix.cuda }}"
|
|
||||||
VERSION="${{ inputs.version }}"
|
|
||||||
|
|
||||||
# 1. Set platform-specific variables
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
|
||||||
PLATFORM="windows"
|
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only_stripped.tar.gz"
|
|
||||||
PIP_PATH="portable_env/python.exe -m pip"
|
|
||||||
PACKAGES_PATH="portable_env/Lib/site-packages"
|
|
||||||
rm start_linux.sh start_macos.sh
|
|
||||||
else
|
|
||||||
PLATFORM="linux"
|
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz"
|
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
|
||||||
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
|
||||||
rm start_macos.sh start_windows.bat
|
|
||||||
fi
|
|
||||||
|
|
||||||
# 2. Download and extract Python
|
|
||||||
cd ..
|
|
||||||
echo "Downloading Python for $PLATFORM..."
|
|
||||||
curl -L -o python-build.tar.gz "$PYTHON_URL"
|
|
||||||
tar -xzf python-build.tar.gz
|
|
||||||
mv python "text-generation-webui-ik-${VERSION_CLEAN}/portable_env"
|
|
||||||
|
|
||||||
# 3. Prepare requirements file based on CUDA version
|
|
||||||
cd "text-generation-webui-ik-${VERSION_CLEAN}"
|
|
||||||
if [[ "$CUDA_VERSION" == "13.1" ]]; then
|
|
||||||
REQ_FILE="requirements/portable/requirements_ik_cuda131.txt"
|
|
||||||
else
|
|
||||||
REQ_FILE="requirements/portable/requirements_ik.txt"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# 4. Inject --ik into start scripts
|
|
||||||
sed -i 's/--portable/--portable --ik/g' start_linux.sh start_windows.bat 2>/dev/null || true
|
|
||||||
|
|
||||||
# 5. Install packages
|
|
||||||
echo "Installing Python packages from $REQ_FILE..."
|
|
||||||
$PIP_PATH install --target="./$PACKAGES_PATH" -r "$REQ_FILE"
|
|
||||||
|
|
||||||
# 6. Clean up
|
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
|
||||||
|
|
||||||
# 7. Create archive
|
|
||||||
cd ..
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
|
||||||
ARCHIVE_NAME="textgen-portable-ik-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.zip"
|
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
powershell -Command "Compress-Archive -Path text-generation-webui-ik-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
|
||||||
else
|
|
||||||
ARCHIVE_NAME="textgen-portable-ik-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.tar.gz"
|
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
tar czf "$ARCHIVE_NAME" "text-generation-webui-ik-${VERSION_CLEAN}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
|
||||||
id: upload-release
|
|
||||||
uses: svenstaro/upload-release-action@2.7.0
|
|
||||||
continue-on-error: true
|
|
||||||
with:
|
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
file: ../textgen-portable-ik-*
|
|
||||||
tag: ${{ inputs.version }}
|
|
||||||
file_glob: true
|
|
||||||
make_latest: false
|
|
||||||
overwrite: true
|
|
||||||
173
.github/workflows/build-portable-release-ik.yml
vendored
173
.github/workflows/build-portable-release-ik.yml
vendored
|
|
@ -1,173 +0,0 @@
|
||||||
name: Build ik CPU
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
version:
|
|
||||||
description: 'Version tag of text-generation-webui to build: v3.0'
|
|
||||||
default: 'v3.0'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
config:
|
|
||||||
description: 'Override configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2'
|
|
||||||
default: 'Default'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
exclude:
|
|
||||||
description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2'
|
|
||||||
default: 'None'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
workflow_call:
|
|
||||||
inputs:
|
|
||||||
version:
|
|
||||||
description: 'Version tag of text-generation-webui to build: v3.0'
|
|
||||||
default: 'v3.0'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
config:
|
|
||||||
description: 'Configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2'
|
|
||||||
default: 'Default'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
exclude:
|
|
||||||
description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2'
|
|
||||||
default: 'None'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
define_matrix:
|
|
||||||
name: Define Build Matrix
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
outputs:
|
|
||||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: pwsh
|
|
||||||
env:
|
|
||||||
CONFIGIN: ${{ inputs.config }}
|
|
||||||
EXCLUDEIN: ${{ inputs.exclude }}
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Define Job Output
|
|
||||||
id: set-matrix
|
|
||||||
run: |
|
|
||||||
$matrix = @{
|
|
||||||
'os' = @('ubuntu-22.04', 'windows-2022')
|
|
||||||
'pyver' = @("3.13")
|
|
||||||
}
|
|
||||||
|
|
||||||
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
|
||||||
|
|
||||||
if ($env:EXCLUDEIN -ne 'None') {
|
|
||||||
$exclusions = @()
|
|
||||||
$exclusions += $env:EXCLUDEIN.split(';').replace(':','=').replace(',',"`n") | ConvertFrom-StringData
|
|
||||||
$matrix['exclude'] = $exclusions
|
|
||||||
}
|
|
||||||
|
|
||||||
$matrixOut = ConvertTo-Json $matrix -Compress
|
|
||||||
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
|
||||||
|
|
||||||
build_wheels:
|
|
||||||
name: ${{ matrix.os }} ${{ matrix.pyver }}
|
|
||||||
needs: define_matrix
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
strategy:
|
|
||||||
matrix: ${{ fromJSON(needs.define_matrix.outputs.matrix) }}
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: pwsh
|
|
||||||
env:
|
|
||||||
PCKGVER: ${{ inputs.version }}
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v6
|
|
||||||
with:
|
|
||||||
repository: 'oobabooga/text-generation-webui'
|
|
||||||
ref: ${{ inputs.version }}
|
|
||||||
submodules: 'recursive'
|
|
||||||
|
|
||||||
- uses: actions/setup-python@v6
|
|
||||||
with:
|
|
||||||
python-version: ${{ matrix.pyver }}
|
|
||||||
|
|
||||||
- name: Build Package
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
VERSION_CLEAN="${{ inputs.version }}"
|
|
||||||
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
|
||||||
cd ..
|
|
||||||
cp -r text-generation-webui "text-generation-webui-ik-${VERSION_CLEAN}"
|
|
||||||
cd "text-generation-webui-ik-${VERSION_CLEAN}"
|
|
||||||
|
|
||||||
# Remove extensions that need additional requirements
|
|
||||||
allowed=("character_bias" "gallery" "sd_api_pictures")
|
|
||||||
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
|
||||||
|
|
||||||
# Define common variables
|
|
||||||
VERSION="${{ inputs.version }}"
|
|
||||||
|
|
||||||
# 1. Set platform-specific variables
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
|
||||||
PLATFORM="windows-cpu"
|
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only_stripped.tar.gz"
|
|
||||||
PIP_PATH="portable_env/python.exe -m pip"
|
|
||||||
PACKAGES_PATH="portable_env/Lib/site-packages"
|
|
||||||
rm start_linux.sh start_macos.sh
|
|
||||||
else
|
|
||||||
PLATFORM="linux-cpu"
|
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz"
|
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
|
||||||
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
|
||||||
rm start_macos.sh start_windows.bat
|
|
||||||
fi
|
|
||||||
|
|
||||||
# 2. Download and extract Python
|
|
||||||
echo "Downloading Python for $PLATFORM..."
|
|
||||||
cd ..
|
|
||||||
curl -L -o python-build.tar.gz "$PYTHON_URL"
|
|
||||||
tar -xzf python-build.tar.gz
|
|
||||||
mv python "text-generation-webui-ik-${VERSION_CLEAN}/portable_env"
|
|
||||||
|
|
||||||
# 3. Prepare requirements file
|
|
||||||
cd "text-generation-webui-ik-${VERSION_CLEAN}"
|
|
||||||
REQ_FILE="requirements/portable/requirements_ik_cpu_only.txt"
|
|
||||||
echo "Using requirements file: $REQ_FILE"
|
|
||||||
|
|
||||||
# 4. Inject --ik into start scripts
|
|
||||||
sed -i 's/--portable/--portable --ik/g' start_linux.sh start_windows.bat 2>/dev/null || true
|
|
||||||
|
|
||||||
# 5. Install packages
|
|
||||||
echo "Installing Python packages from $REQ_FILE..."
|
|
||||||
$PIP_PATH install --target="./$PACKAGES_PATH" -r "$REQ_FILE"
|
|
||||||
|
|
||||||
# 6. Clean up
|
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
|
||||||
|
|
||||||
# 7. Create archive
|
|
||||||
cd ..
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
|
||||||
ARCHIVE_NAME="textgen-portable-ik-${VERSION_CLEAN}-${PLATFORM}.zip"
|
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
powershell -Command "Compress-Archive -Path text-generation-webui-ik-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
|
||||||
else
|
|
||||||
ARCHIVE_NAME="textgen-portable-ik-${VERSION_CLEAN}-${PLATFORM}.tar.gz"
|
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
tar czf "$ARCHIVE_NAME" "text-generation-webui-ik-${VERSION_CLEAN}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
|
||||||
id: upload-release
|
|
||||||
uses: svenstaro/upload-release-action@2.7.0
|
|
||||||
continue-on-error: true
|
|
||||||
with:
|
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
file: ../textgen-portable-ik-*
|
|
||||||
tag: ${{ inputs.version }}
|
|
||||||
file_glob: true
|
|
||||||
make_latest: false
|
|
||||||
overwrite: true
|
|
||||||
|
|
@ -105,7 +105,7 @@ jobs:
|
||||||
cd "text-generation-webui-${VERSION_CLEAN}"
|
cd "text-generation-webui-${VERSION_CLEAN}"
|
||||||
|
|
||||||
# Remove extensions that need additional requirements
|
# Remove extensions that need additional requirements
|
||||||
allowed=("character_bias" "gallery" "sd_api_pictures")
|
allowed=("character_bias" "gallery" "openai" "sd_api_pictures")
|
||||||
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
||||||
|
|
||||||
# Define common variables
|
# Define common variables
|
||||||
|
|
@ -114,13 +114,13 @@ jobs:
|
||||||
# 1. Set platform-specific variables
|
# 1. Set platform-specific variables
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
PLATFORM="windows"
|
PLATFORM="windows"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/python.exe -m pip"
|
PIP_PATH="portable_env/python.exe -m pip"
|
||||||
PACKAGES_PATH="portable_env/Lib/site-packages"
|
PACKAGES_PATH="portable_env/Lib/site-packages"
|
||||||
rm start_linux.sh start_macos.sh
|
rm start_linux.sh start_macos.sh
|
||||||
else
|
else
|
||||||
PLATFORM="linux"
|
PLATFORM="linux"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
||||||
rm start_macos.sh start_windows.bat
|
rm start_macos.sh start_windows.bat
|
||||||
|
|
@ -145,16 +145,15 @@ jobs:
|
||||||
# 5. Clean up
|
# 5. Clean up
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
||||||
|
|
||||||
# 6. Create archive
|
# 6. Create ZIP file
|
||||||
cd ..
|
cd ..
|
||||||
|
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm.zip"
|
||||||
|
echo "Creating archive: $ZIP_NAME"
|
||||||
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.zip"
|
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
|
||||||
else
|
else
|
||||||
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.tar.gz"
|
zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
- name: Upload files to a GitHub release
|
||||||
|
|
@ -163,7 +162,7 @@ jobs:
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
file: ../textgen-portable-*
|
file: ../textgen-portable-*.zip
|
||||||
tag: ${{ inputs.version }}
|
tag: ${{ inputs.version }}
|
||||||
file_glob: true
|
file_glob: true
|
||||||
make_latest: false
|
make_latest: false
|
||||||
|
|
|
||||||
|
|
@ -105,7 +105,7 @@ jobs:
|
||||||
cd "text-generation-webui-${VERSION_CLEAN}"
|
cd "text-generation-webui-${VERSION_CLEAN}"
|
||||||
|
|
||||||
# Remove extensions that need additional requirements
|
# Remove extensions that need additional requirements
|
||||||
allowed=("character_bias" "gallery" "sd_api_pictures")
|
allowed=("character_bias" "gallery" "openai" "sd_api_pictures")
|
||||||
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
||||||
|
|
||||||
# Define common variables
|
# Define common variables
|
||||||
|
|
@ -114,13 +114,13 @@ jobs:
|
||||||
# 1. Set platform-specific variables
|
# 1. Set platform-specific variables
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
PLATFORM="windows"
|
PLATFORM="windows"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/python.exe -m pip"
|
PIP_PATH="portable_env/python.exe -m pip"
|
||||||
PACKAGES_PATH="portable_env/Lib/site-packages"
|
PACKAGES_PATH="portable_env/Lib/site-packages"
|
||||||
rm start_linux.sh start_macos.sh
|
rm start_linux.sh start_macos.sh
|
||||||
else
|
else
|
||||||
PLATFORM="linux"
|
PLATFORM="linux"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
||||||
rm start_macos.sh start_windows.bat
|
rm start_macos.sh start_windows.bat
|
||||||
|
|
@ -145,16 +145,15 @@ jobs:
|
||||||
# 5. Clean up
|
# 5. Clean up
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
||||||
|
|
||||||
# 6. Create archive
|
# 6. Create ZIP file
|
||||||
cd ..
|
cd ..
|
||||||
|
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.zip"
|
||||||
|
echo "Creating archive: $ZIP_NAME"
|
||||||
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.zip"
|
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
|
||||||
else
|
else
|
||||||
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.tar.gz"
|
zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
- name: Upload files to a GitHub release
|
||||||
|
|
@ -163,7 +162,7 @@ jobs:
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
file: ../textgen-portable-*
|
file: ../textgen-portable-*.zip
|
||||||
tag: ${{ inputs.version }}
|
tag: ${{ inputs.version }}
|
||||||
file_glob: true
|
file_glob: true
|
||||||
make_latest: false
|
make_latest: false
|
||||||
|
|
|
||||||
29
.github/workflows/build-portable-release.yml
vendored
29
.github/workflows/build-portable-release.yml
vendored
|
|
@ -105,7 +105,7 @@ jobs:
|
||||||
cd "text-generation-webui-${VERSION_CLEAN}"
|
cd "text-generation-webui-${VERSION_CLEAN}"
|
||||||
|
|
||||||
# Remove extensions that need additional requirements
|
# Remove extensions that need additional requirements
|
||||||
allowed=("character_bias" "gallery" "sd_api_pictures")
|
allowed=("character_bias" "gallery" "openai" "sd_api_pictures")
|
||||||
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
||||||
|
|
||||||
# Define common variables
|
# Define common variables
|
||||||
|
|
@ -115,18 +115,18 @@ jobs:
|
||||||
# 1. Set platform-specific variables
|
# 1. Set platform-specific variables
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
PLATFORM="windows-cpu"
|
PLATFORM="windows-cpu"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/python.exe -m pip"
|
PIP_PATH="portable_env/python.exe -m pip"
|
||||||
PACKAGES_PATH="portable_env/Lib/site-packages"
|
PACKAGES_PATH="portable_env/Lib/site-packages"
|
||||||
rm start_linux.sh start_macos.sh
|
rm start_linux.sh start_macos.sh
|
||||||
elif [[ "$RUNNER_OS" == "macOS" ]]; then
|
elif [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||||
if [[ "$OS_TYPE" == "macos-15-intel" ]]; then
|
if [[ "$OS_TYPE" == "macos-13" ]]; then
|
||||||
PLATFORM="macos-x86_64"
|
PLATFORM="macos-x86_64"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-apple-darwin-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-apple-darwin-install_only.tar.gz"
|
||||||
REQ_TYPE="apple_intel"
|
REQ_TYPE="apple_intel"
|
||||||
else
|
else
|
||||||
PLATFORM="macos-arm64"
|
PLATFORM="macos-arm64"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-aarch64-apple-darwin-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-aarch64-apple-darwin-install_only.tar.gz"
|
||||||
REQ_TYPE="apple_silicon"
|
REQ_TYPE="apple_silicon"
|
||||||
fi
|
fi
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
|
|
@ -135,7 +135,7 @@ jobs:
|
||||||
else
|
else
|
||||||
# Linux case
|
# Linux case
|
||||||
PLATFORM="linux-cpu"
|
PLATFORM="linux-cpu"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
||||||
rm start_macos.sh start_windows.bat
|
rm start_macos.sh start_windows.bat
|
||||||
|
|
@ -153,7 +153,7 @@ jobs:
|
||||||
|
|
||||||
# Select requirements file based on platform
|
# Select requirements file based on platform
|
||||||
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||||
if [[ "$OS_TYPE" == "macos-15-intel" ]]; then
|
if [[ "$OS_TYPE" == "macos-13" ]]; then
|
||||||
REQ_FILE="requirements/portable/requirements_apple_intel.txt"
|
REQ_FILE="requirements/portable/requirements_apple_intel.txt"
|
||||||
else
|
else
|
||||||
REQ_FILE="requirements/portable/requirements_apple_silicon.txt"
|
REQ_FILE="requirements/portable/requirements_apple_silicon.txt"
|
||||||
|
|
@ -171,16 +171,15 @@ jobs:
|
||||||
# 5. Clean up
|
# 5. Clean up
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
||||||
|
|
||||||
# 6. Create archive
|
# 6. Create ZIP file
|
||||||
cd ..
|
cd ..
|
||||||
|
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.zip"
|
||||||
|
echo "Creating archive: $ZIP_NAME"
|
||||||
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.zip"
|
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
|
||||||
else
|
else
|
||||||
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.tar.gz"
|
zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
||||||
echo "Creating archive: $ARCHIVE_NAME"
|
|
||||||
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
- name: Upload files to a GitHub release
|
||||||
|
|
@ -189,7 +188,7 @@ jobs:
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
file: ../textgen-portable-*
|
file: ../textgen-portable-*.zip
|
||||||
tag: ${{ inputs.version }}
|
tag: ${{ inputs.version }}
|
||||||
file_glob: true
|
file_glob: true
|
||||||
make_latest: false
|
make_latest: false
|
||||||
|
|
|
||||||
162
README.md
162
README.md
|
|
@ -13,7 +13,7 @@
|
||||||
|
|
||||||
# Text Generation Web UI
|
# Text Generation Web UI
|
||||||
|
|
||||||
A Gradio web UI for running Large Language Models locally. 100% private and offline. Supports text generation, vision, tool-calling, training, image generation, and more.
|
A Gradio web UI for running Large Language Models locally. 100% private, offline, and free.
|
||||||
|
|
||||||
[Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason)
|
[Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason)
|
||||||
|
|
||||||
|
|
@ -23,20 +23,22 @@ A Gradio web UI for running Large Language Models locally. 100% private and offl
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Easy setup**: [Portable builds](https://github.com/oobabooga/text-generation-webui/releases) (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or a one-click installer for the full feature set.
|
- Supports multiple local text generation backends, including [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) (the latter via its own [Dockerfile](https://github.com/oobabooga/text-generation-webui/blob/main/docker/TensorRT-LLM/Dockerfile)).
|
||||||
- **Multiple backends**: [llama.cpp](https://github.com/ggerganov/llama.cpp), [ik_llama.cpp](https://github.com/ikawrakow/ik_llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Switch between backends and models without restarting.
|
- Easy setup: Choose between **portable builds** (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or the one-click installer that creates a self-contained `installer_files` directory.
|
||||||
- **OpenAI/Anthropic-compatible API**: Chat, Completions, and Messages endpoints with tool-calling support. Use as a local drop-in replacement for the OpenAI/Anthropic APIs ([examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples)).
|
|
||||||
- **Tool-calling**: Models can call custom functions during chat — web search, page fetching, math, and more. Each tool is a single `.py` file. MCP servers are also supported ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Tool-Calling-Tutorial)).
|
|
||||||
- **Vision (multimodal)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
|
|
||||||
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
|
|
||||||
- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)).
|
|
||||||
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
|
|
||||||
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
|
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
|
||||||
- `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters. Prompts are automatically formatted with Jinja2 templates.
|
- `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters. Prompts are automatically formatted with Jinja2 templates.
|
||||||
|
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
|
||||||
|
- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
|
||||||
|
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
|
||||||
|
- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
|
||||||
|
- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Easy to use, good defaults, and supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)).
|
||||||
- Edit messages, navigate between message versions, and branch conversations at any point.
|
- Edit messages, navigate between message versions, and branch conversations at any point.
|
||||||
|
- Switch between different models in the UI without restarting.
|
||||||
- Free-form text generation in the Notebook tab without being limited to chat turns.
|
- Free-form text generation in the Notebook tab without being limited to chat turns.
|
||||||
- Multiple sampling parameters and generation options for sophisticated text generation control.
|
- Multiple sampling parameters and generation options for sophisticated text generation control.
|
||||||
- Dark/light themes, syntax highlighting for code blocks, and LaTeX rendering for mathematical expressions.
|
- Aesthetic UI with dark and light themes.
|
||||||
|
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
|
||||||
|
- OpenAI-compatible API with Chat and Completions endpoints, including tool-calling support – see [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples).
|
||||||
- Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.
|
- Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.
|
||||||
|
|
||||||
## How to install
|
## How to install
|
||||||
|
|
@ -45,10 +47,9 @@ A Gradio web UI for running Large Language Models locally. 100% private and offl
|
||||||
|
|
||||||
No installation needed – just download, unzip and run. All dependencies included.
|
No installation needed – just download, unzip and run. All dependencies included.
|
||||||
|
|
||||||
Download from here: **https://github.com/oobabooga/text-generation-webui/releases**
|
Compatible with GGUF (llama.cpp) models on Windows, Linux, and macOS. [Check what models fit your hardware](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator).
|
||||||
|
|
||||||
- Builds are provided for Linux, Windows, and macOS, with options for CUDA, Vulkan, ROCm, and CPU-only.
|
Download from here: **https://github.com/oobabooga/text-generation-webui/releases**
|
||||||
- Compatible with GGUF (llama.cpp) models.
|
|
||||||
|
|
||||||
#### Option 2: Manual portable install with venv
|
#### Option 2: Manual portable install with venv
|
||||||
|
|
||||||
|
|
@ -80,7 +81,7 @@ deactivate
|
||||||
|
|
||||||
#### Option 3: One-click installer
|
#### Option 3: One-click installer
|
||||||
|
|
||||||
For users who need additional backends (ExLlamaV3, Transformers), training, image generation, or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch.
|
For users who need additional backends (ExLlamaV3, Transformers) or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch.
|
||||||
|
|
||||||
1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it.
|
1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it.
|
||||||
2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`.
|
2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`.
|
||||||
|
|
@ -145,7 +146,7 @@ conda activate textgen
|
||||||
|--------|---------|---------|
|
|--------|---------|---------|
|
||||||
| Linux/WSL | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
|
| Linux/WSL | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
|
||||||
| Linux/WSL | CPU only | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu` |
|
| Linux/WSL | CPU only | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu` |
|
||||||
| Linux | AMD | `pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp313-cp313-linux_x86_64.whl` |
|
| Linux | AMD | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/rocm6.4` |
|
||||||
| MacOS + MPS | Any | `pip3 install torch==2.9.1` |
|
| MacOS + MPS | Any | `pip3 install torch==2.9.1` |
|
||||||
| Windows | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
|
| Windows | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
|
||||||
| Windows | CPU only | `pip3 install torch==2.9.1` |
|
| Windows | CPU only | `pip3 install torch==2.9.1` |
|
||||||
|
|
@ -200,7 +201,7 @@ ln -s docker/{nvidia/Dockerfile,nvidia/docker-compose.yml,.dockerignore} .
|
||||||
For AMD GPU:
|
For AMD GPU:
|
||||||
ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} .
|
ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} .
|
||||||
For Intel GPU:
|
For Intel GPU:
|
||||||
ln -s docker/{intel/Dockerfile,intel/docker-compose.yml,.dockerignore} .
|
ln -s docker/{intel/Dockerfile,amd/docker-compose.yml,.dockerignore} .
|
||||||
For CPU only
|
For CPU only
|
||||||
ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} .
|
ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} .
|
||||||
cp docker/.env.example .env
|
cp docker/.env.example .env
|
||||||
|
|
@ -235,24 +236,20 @@ List of command-line flags
|
||||||
</summary>
|
</summary>
|
||||||
|
|
||||||
```txt
|
```txt
|
||||||
usage: server.py [-h] [--user-data-dir USER_DATA_DIR] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS]
|
usage: server.py [-h] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS]
|
||||||
[--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--image-model IMAGE_MODEL] [--image-model-dir IMAGE_MODEL_DIR] [--image-dtype {bfloat16,float16}]
|
[--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--image-model IMAGE_MODEL] [--image-model-dir IMAGE_MODEL_DIR] [--image-dtype {bfloat16,float16}]
|
||||||
[--image-attn-backend {flash_attention_2,sdpa}] [--image-cpu-offload] [--image-compile] [--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}]
|
[--image-attn-backend {flash_attention_2,sdpa}] [--image-cpu-offload] [--image-compile] [--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}]
|
||||||
[--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT] [--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT]
|
[--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT] [--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT]
|
||||||
[--ctx-size-draft CTX_SIZE_DRAFT] [--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}] [--spec-ngram-size-n SPEC_NGRAM_SIZE_N]
|
[--ctx-size-draft CTX_SIZE_DRAFT] [--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}] [--spec-ngram-size-n SPEC_NGRAM_SIZE_N]
|
||||||
[--spec-ngram-size-m SPEC_NGRAM_SIZE_M] [--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS] [--gpu-layers N] [--cpu-moe] [--mmproj MMPROJ] [--streaming-llm] [--tensor-split TENSOR_SPLIT]
|
[--spec-ngram-size-m SPEC_NGRAM_SIZE_M] [--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS] [--gpu-layers N] [--cpu-moe] [--mmproj MMPROJ] [--streaming-llm] [--tensor-split TENSOR_SPLIT]
|
||||||
[--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--ubatch-size UBATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa]
|
[--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--ubatch-size UBATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa]
|
||||||
[--parallel PARALLEL] [--fit-target FIT_TARGET] [--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16]
|
[--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16] [--no-cache] [--trust-remote-code]
|
||||||
[--no-cache] [--trust-remote-code] [--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE]
|
[--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE] [--quant_type QUANT_TYPE]
|
||||||
[--quant_type QUANT_TYPE] [--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share]
|
[--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--cpp-runner]
|
||||||
|
[--alpha_value ALPHA_VALUE] [--rope_freq_base ROPE_FREQ_BASE] [--compress_pos_emb COMPRESS_POS_EMB] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share]
|
||||||
[--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors]
|
[--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors]
|
||||||
[--portable] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4]
|
[--portable] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4]
|
||||||
[--nowebui] [--temperature N] [--dynatemp-low N] [--dynatemp-high N] [--dynatemp-exponent N] [--smoothing-factor N] [--smoothing-curve N] [--min-p N] [--top-p N] [--top-k N]
|
[--nowebui]
|
||||||
[--typical-p N] [--xtc-threshold N] [--xtc-probability N] [--epsilon-cutoff N] [--eta-cutoff N] [--tfs N] [--top-a N] [--top-n-sigma N] [--adaptive-target N] [--adaptive-decay N]
|
|
||||||
[--dry-multiplier N] [--dry-allowed-length N] [--dry-base N] [--repetition-penalty N] [--frequency-penalty N] [--presence-penalty N] [--encoder-repetition-penalty N]
|
|
||||||
[--no-repeat-ngram-size N] [--repetition-penalty-range N] [--penalty-alpha N] [--guidance-scale N] [--mirostat-mode N] [--mirostat-tau N] [--mirostat-eta N]
|
|
||||||
[--do-sample | --no-do-sample] [--dynamic-temperature | --no-dynamic-temperature] [--temperature-last | --no-temperature-last] [--sampler-priority N] [--dry-sequence-breakers N]
|
|
||||||
[--enable-thinking | --no-enable-thinking] [--reasoning-effort N] [--chat-template-file CHAT_TEMPLATE_FILE]
|
|
||||||
|
|
||||||
Text Generation Web UI
|
Text Generation Web UI
|
||||||
|
|
||||||
|
|
@ -260,8 +257,7 @@ options:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
|
|
||||||
Basic settings:
|
Basic settings:
|
||||||
--user-data-dir USER_DATA_DIR Path to the user data directory. Default: auto-detected.
|
--multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.
|
||||||
--multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.
|
|
||||||
--model MODEL Name of the model to load by default.
|
--model MODEL Name of the model to load by default.
|
||||||
--lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.
|
--lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.
|
||||||
--model-dir MODEL_DIR Path to directory with all the models.
|
--model-dir MODEL_DIR Path to directory with all the models.
|
||||||
|
|
@ -284,12 +280,12 @@ Image model:
|
||||||
Quantization method for image model.
|
Quantization method for image model.
|
||||||
|
|
||||||
Model loader:
|
Model loader:
|
||||||
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT-
|
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3,
|
||||||
LLM.
|
TensorRT-LLM.
|
||||||
|
|
||||||
Context and cache:
|
Context and cache:
|
||||||
--ctx-size, --n_ctx, --max_seq_len N Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.
|
--ctx-size N, --n_ctx N, --max_seq_len N Context size in tokens. llama.cpp: 0 = auto if gpu-layers is also -1.
|
||||||
--cache-type, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).
|
--cache-type N, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).
|
||||||
|
|
||||||
Speculative decoding:
|
Speculative decoding:
|
||||||
--model-draft MODEL_DRAFT Path to the draft model for speculative decoding.
|
--model-draft MODEL_DRAFT Path to the draft model for speculative decoding.
|
||||||
|
|
@ -304,7 +300,7 @@ Speculative decoding:
|
||||||
--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS Minimum n-gram hits for ngram-map speculative decoding.
|
--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS Minimum n-gram hits for ngram-map speculative decoding.
|
||||||
|
|
||||||
llama.cpp:
|
llama.cpp:
|
||||||
--gpu-layers, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto.
|
--gpu-layers N, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto.
|
||||||
--cpu-moe Move the experts to the CPU (for MoE models).
|
--cpu-moe Move the experts to the CPU (for MoE models).
|
||||||
--mmproj MMPROJ Path to the mmproj file for vision models.
|
--mmproj MMPROJ Path to the mmproj file for vision models.
|
||||||
--streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.
|
--streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.
|
||||||
|
|
@ -312,23 +308,19 @@ llama.cpp:
|
||||||
--row-split Split the model by rows across GPUs. This may improve multi-gpu performance.
|
--row-split Split the model by rows across GPUs. This may improve multi-gpu performance.
|
||||||
--no-mmap Prevent mmap from being used.
|
--no-mmap Prevent mmap from being used.
|
||||||
--mlock Force the system to keep the model in RAM.
|
--mlock Force the system to keep the model in RAM.
|
||||||
--no-kv-offload Do not offload the K, Q, V to the GPU. This saves VRAM but reduces performance.
|
--no-kv-offload Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.
|
||||||
--batch-size BATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the application level batch size.
|
--batch-size BATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the application level batch size.
|
||||||
--ubatch-size UBATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the max physical batch size for computation (device level).
|
--ubatch-size UBATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the max physical batch size for computation (device level).
|
||||||
--threads THREADS Number of threads to use.
|
--threads THREADS Number of threads to use.
|
||||||
--threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing.
|
--threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing.
|
||||||
--numa Activate NUMA task allocation for llama.cpp.
|
--numa Activate NUMA task allocation for llama.cpp.
|
||||||
--parallel PARALLEL Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set
|
|
||||||
ctx_size to 32768.
|
|
||||||
--fit-target FIT_TARGET Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.
|
|
||||||
Default: 1024.
|
|
||||||
--extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"
|
--extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"
|
||||||
|
|
||||||
Transformers/Accelerate:
|
Transformers/Accelerate:
|
||||||
--cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow.
|
--cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow.
|
||||||
--cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading.
|
--cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading.
|
||||||
--disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.
|
--disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.
|
||||||
--disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to.
|
--disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to. Defaults to "user_data/cache".
|
||||||
--load-in-8bit Load the model with 8-bit precision (using bitsandbytes).
|
--load-in-8bit Load the model with 8-bit precision (using bitsandbytes).
|
||||||
--bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
|
--bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
|
||||||
--no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.
|
--no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.
|
||||||
|
|
@ -349,6 +341,14 @@ ExLlamaV3:
|
||||||
--tp-backend TP_BACKEND The backend for tensor parallelism. Valid options: native, nccl. Default: native.
|
--tp-backend TP_BACKEND The backend for tensor parallelism. Valid options: native, nccl. Default: native.
|
||||||
--cfg-cache Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.
|
--cfg-cache Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.
|
||||||
|
|
||||||
|
TensorRT-LLM:
|
||||||
|
--cpp-runner Use the ModelRunnerCpp runner, which is faster than the default ModelRunner.
|
||||||
|
|
||||||
|
RoPE:
|
||||||
|
--alpha_value ALPHA_VALUE Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.
|
||||||
|
--rope_freq_base ROPE_FREQ_BASE If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).
|
||||||
|
--compress_pos_emb COMPRESS_POS_EMB Positional embeddings compression factor. Should be set to (context length) / (model's original context length). Equal to 1/rope_freq_scale.
|
||||||
|
|
||||||
Gradio:
|
Gradio:
|
||||||
--listen Make the web UI reachable from your local network.
|
--listen Make the web UI reachable from your local network.
|
||||||
--listen-port LISTEN_PORT The listening port that the server will use.
|
--listen-port LISTEN_PORT The listening port that the server will use.
|
||||||
|
|
@ -365,7 +365,7 @@ Gradio:
|
||||||
|
|
||||||
API:
|
API:
|
||||||
--api Enable the API extension.
|
--api Enable the API extension.
|
||||||
--public-api Create a public URL for the API using Cloudflare.
|
--public-api Create a public URL for the API using Cloudfare.
|
||||||
--public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.
|
--public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.
|
||||||
--api-port API_PORT The listening port for the API.
|
--api-port API_PORT The listening port for the API.
|
||||||
--api-key API_KEY API authentication key.
|
--api-key API_KEY API authentication key.
|
||||||
|
|
@ -373,67 +373,28 @@ API:
|
||||||
--api-enable-ipv6 Enable IPv6 for the API
|
--api-enable-ipv6 Enable IPv6 for the API
|
||||||
--api-disable-ipv4 Disable IPv4 for the API
|
--api-disable-ipv4 Disable IPv4 for the API
|
||||||
--nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode.
|
--nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode.
|
||||||
|
|
||||||
API generation defaults:
|
|
||||||
--temperature N Temperature
|
|
||||||
--dynatemp-low N Dynamic temperature low
|
|
||||||
--dynatemp-high N Dynamic temperature high
|
|
||||||
--dynatemp-exponent N Dynamic temperature exponent
|
|
||||||
--smoothing-factor N Smoothing factor
|
|
||||||
--smoothing-curve N Smoothing curve
|
|
||||||
--min-p N Min P
|
|
||||||
--top-p N Top P
|
|
||||||
--top-k N Top K
|
|
||||||
--typical-p N Typical P
|
|
||||||
--xtc-threshold N XTC threshold
|
|
||||||
--xtc-probability N XTC probability
|
|
||||||
--epsilon-cutoff N Epsilon cutoff
|
|
||||||
--eta-cutoff N Eta cutoff
|
|
||||||
--tfs N TFS
|
|
||||||
--top-a N Top A
|
|
||||||
--top-n-sigma N Top N Sigma
|
|
||||||
--adaptive-target N Adaptive target
|
|
||||||
--adaptive-decay N Adaptive decay
|
|
||||||
--dry-multiplier N DRY multiplier
|
|
||||||
--dry-allowed-length N DRY allowed length
|
|
||||||
--dry-base N DRY base
|
|
||||||
--repetition-penalty N Repetition penalty
|
|
||||||
--frequency-penalty N Frequency penalty
|
|
||||||
--presence-penalty N Presence penalty
|
|
||||||
--encoder-repetition-penalty N Encoder repetition penalty
|
|
||||||
--no-repeat-ngram-size N No repeat ngram size
|
|
||||||
--repetition-penalty-range N Repetition penalty range
|
|
||||||
--penalty-alpha N Penalty alpha
|
|
||||||
--guidance-scale N Guidance scale
|
|
||||||
--mirostat-mode N Mirostat mode
|
|
||||||
--mirostat-tau N Mirostat tau
|
|
||||||
--mirostat-eta N Mirostat eta
|
|
||||||
--do-sample, --no-do-sample Do sample
|
|
||||||
--dynamic-temperature, --no-dynamic-temperature Dynamic temperature
|
|
||||||
--temperature-last, --no-temperature-last Temperature last
|
|
||||||
--sampler-priority N Sampler priority
|
|
||||||
--dry-sequence-breakers N DRY sequence breakers
|
|
||||||
--enable-thinking, --no-enable-thinking Enable thinking
|
|
||||||
--reasoning-effort N Reasoning effort
|
|
||||||
--chat-template-file CHAT_TEMPLATE_FILE Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model's
|
|
||||||
built-in template.
|
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Downloading models
|
## Downloading models
|
||||||
|
|
||||||
1. Download a GGUF model file from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf).
|
Models should be placed in the folder `text-generation-webui/user_data/models`. They are usually downloaded from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf).
|
||||||
2. Place it in the `user_data/models` folder.
|
|
||||||
|
|
||||||
That's it. The UI will detect it automatically.
|
To check if a GGUF model will fit in your hardware before downloading it, you can use this tool I created:
|
||||||
|
|
||||||
To estimate how much memory a model will use, you can use the [GGUF Memory Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator).
|
[Accurate GGUF VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator)
|
||||||
|
|
||||||
<details>
|
* GGUF models are a single file and should be placed directly into `user_data/models`. Example:
|
||||||
<summary>Other model types (Transformers, EXL3)</summary>
|
|
||||||
|
|
||||||
Models that consist of multiple files (like 16-bit Transformers models and EXL3 models) should be placed in a subfolder inside `user_data/models`:
|
```
|
||||||
|
text-generation-webui
|
||||||
|
└── user_data
|
||||||
|
└── models
|
||||||
|
└── llama-2-13b-chat.Q4_K_M.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
* The remaining model types (like 16-bit Transformers models and EXL3 models) are made of several files and must be placed in a subfolder. Example:
|
||||||
|
|
||||||
```
|
```
|
||||||
text-generation-webui
|
text-generation-webui
|
||||||
|
|
@ -443,18 +404,31 @@ text-generation-webui
|
||||||
├── config.json
|
├── config.json
|
||||||
├── generation_config.json
|
├── generation_config.json
|
||||||
├── model-00001-of-00004.safetensors
|
├── model-00001-of-00004.safetensors
|
||||||
├── ...
|
├── model-00002-of-00004.safetensors
|
||||||
|
├── model-00003-of-00004.safetensors
|
||||||
|
├── model-00004-of-00004.safetensors
|
||||||
|
├── model.safetensors.index.json
|
||||||
|
├── special_tokens_map.json
|
||||||
├── tokenizer_config.json
|
├── tokenizer_config.json
|
||||||
└── tokenizer.json
|
└── tokenizer.json
|
||||||
```
|
```
|
||||||
|
|
||||||
These formats require the one-click installer (not the portable build).
|
In both cases, you can use the "Model" tab of the UI to download the model from Hugging Face automatically. It is also possible to download it via the command-line with:
|
||||||
</details>
|
|
||||||
|
```
|
||||||
|
python download-model.py organization/model
|
||||||
|
```
|
||||||
|
|
||||||
|
Run `python download-model.py --help` to see all the options.
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
https://github.com/oobabooga/text-generation-webui/wiki
|
https://github.com/oobabooga/text-generation-webui/wiki
|
||||||
|
|
||||||
|
## Google Colab notebook
|
||||||
|
|
||||||
|
https://colab.research.google.com/github/oobabooga/text-generation-webui/blob/main/Colab-TextGen-GPU.ipynb
|
||||||
|
|
||||||
## Community
|
## Community
|
||||||
|
|
||||||
https://www.reddit.com/r/Oobabooga/
|
https://www.reddit.com/r/Oobabooga/
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,6 @@ set INSTALL_ENV_DIR=%cd%\installer_files\env
|
||||||
set PYTHONNOUSERSITE=1
|
set PYTHONNOUSERSITE=1
|
||||||
set PYTHONPATH=
|
set PYTHONPATH=
|
||||||
set PYTHONHOME=
|
set PYTHONHOME=
|
||||||
set PYTHONUTF8=1
|
|
||||||
set "CUDA_PATH=%INSTALL_ENV_DIR%"
|
set "CUDA_PATH=%INSTALL_ENV_DIR%"
|
||||||
set "CUDA_HOME=%CUDA_PATH%"
|
set "CUDA_HOME=%CUDA_PATH%"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
display: grid;
|
display: grid;
|
||||||
align-items: start;
|
align-items: start;
|
||||||
grid-template-columns: 60px minmax(0, 1fr);
|
grid-template-columns: 60px minmax(0, 1fr);
|
||||||
width: min(100%, calc(724px + 60px));
|
|
||||||
padding-bottom: 22px;
|
padding-bottom: 22px;
|
||||||
padding-top: 6px;
|
padding-top: 6px;
|
||||||
font-size: 18px;
|
font-size: 18px;
|
||||||
|
|
@ -92,6 +91,9 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p {
|
||||||
|
margin-bottom: 0 !important;
|
||||||
|
font-size: 16px !important;
|
||||||
|
line-height: 1.5 !important;
|
||||||
color: #e0e0e0 !important; /* Light color for text */
|
color: #e0e0e0 !important; /* Light color for text */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -120,7 +122,7 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p {
|
||||||
font-size: 14px !important;
|
font-size: 14px !important; /* Smaller text for mobile */
|
||||||
}
|
}
|
||||||
|
|
||||||
.username {
|
.username {
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@
|
||||||
display: grid;
|
display: grid;
|
||||||
align-items: start;
|
align-items: start;
|
||||||
grid-template-columns: 60px minmax(0, 1fr);
|
grid-template-columns: 60px minmax(0, 1fr);
|
||||||
width: min(100%, calc(724px + 60px + 90px));
|
|
||||||
padding-bottom: 21px;
|
padding-bottom: 21px;
|
||||||
padding-top: 7px;
|
padding-top: 7px;
|
||||||
font-size: 18px;
|
font-size: 18px;
|
||||||
|
|
@ -87,8 +86,10 @@
|
||||||
border-radius: 20px;
|
border-radius: 20px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p, .message-body li {
|
.message-body p {
|
||||||
|
margin-bottom: 0 !important;
|
||||||
font-size: 18px !important;
|
font-size: 18px !important;
|
||||||
|
line-height: 1.428571429 !important;
|
||||||
color: rgb(243 244 246) !important;
|
color: rgb(243 244 246) !important;
|
||||||
text-shadow: 2px 2px 2px rgb(0 0 0);
|
text-shadow: 2px 2px 2px rgb(0 0 0);
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
|
|
@ -126,7 +127,7 @@
|
||||||
padding-left: 0;
|
padding-left: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p, .message-body li {
|
.message-body p {
|
||||||
font-size: 16px !important;
|
font-size: 16px !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,5 +19,4 @@
|
||||||
padding-bottom: 1.5em;
|
padding-bottom: 1.5em;
|
||||||
padding-top: 0.5em;
|
padding-top: 0.5em;
|
||||||
grid-template-columns: 70px minmax(0, 1fr);
|
grid-template-columns: 70px minmax(0, 1fr);
|
||||||
width: min(100%, calc(724px + 70px));
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
display: grid;
|
display: grid;
|
||||||
align-items: start;
|
align-items: start;
|
||||||
grid-template-columns: 60px minmax(0, 1fr);
|
grid-template-columns: 60px minmax(0, 1fr);
|
||||||
width: min(100%, calc(724px + 60px));
|
|
||||||
padding-bottom: 1.5em;
|
padding-bottom: 1.5em;
|
||||||
padding-top: 0.5em;
|
padding-top: 0.5em;
|
||||||
font-size: 15px;
|
font-size: 15px;
|
||||||
|
|
@ -47,10 +46,16 @@
|
||||||
border-radius: 20px;
|
border-radius: 20px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p, .message-body li {
|
.message-body p {
|
||||||
|
font-size: 15px !important;
|
||||||
|
line-height: 22.5px !important;
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.message-body p, .chat .message-body ul, .chat .message-body ol {
|
||||||
|
margin-bottom: 10px !important;
|
||||||
|
}
|
||||||
|
|
||||||
.dark .message-body p em {
|
.dark .message-body p em {
|
||||||
color: rgb(138 138 138) !important;
|
color: rgb(138 138 138) !important;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
.message {
|
.message {
|
||||||
width: min(100%, calc(724px + 60px));
|
|
||||||
padding-bottom: 22px;
|
padding-bottom: 22px;
|
||||||
padding-top: 3px;
|
padding-top: 3px;
|
||||||
font-size: 15px;
|
font-size: 15px;
|
||||||
|
|
@ -61,10 +60,8 @@
|
||||||
text-align: right;
|
text-align: right;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .circle-bot + .text div, .dark .circle-bot + .text *,
|
.dark .circle-bot + .text div, .dark .circle-bot + .text * {
|
||||||
.dark .chat .message .circle-bot + .text .message-body :is(h1, h2, h3, h4, h5, h6),
|
color: #000;
|
||||||
.dark .chat .message .circle-bot + .text .message-body a {
|
|
||||||
color: #000 !important;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.text {
|
.text {
|
||||||
|
|
@ -79,14 +76,19 @@
|
||||||
font-weight: bold;
|
font-weight: bold;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.message-body {
|
||||||
|
}
|
||||||
|
|
||||||
.message-body img {
|
.message-body img {
|
||||||
max-width: 300px;
|
max-width: 300px;
|
||||||
max-height: 300px;
|
max-height: 300px;
|
||||||
border-radius: 20px;
|
border-radius: 20px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p, .message-body li {
|
.message-body p {
|
||||||
|
margin-bottom: 0 !important;
|
||||||
font-size: 15px !important;
|
font-size: 15px !important;
|
||||||
|
line-height: 1.428571429 !important;
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
.message {
|
.message {
|
||||||
display: block;
|
display: block;
|
||||||
width: min(100%, 724px);
|
|
||||||
padding-top: 0;
|
padding-top: 0;
|
||||||
padding-bottom: 21px;
|
padding-bottom: 21px;
|
||||||
font-size: 15px;
|
font-size: 15px;
|
||||||
|
|
@ -78,8 +77,14 @@
|
||||||
border-radius: 12px;
|
border-radius: 12px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p, .message-body li {
|
.message-body p {
|
||||||
font-size: 15px !important;
|
font-size: 15px !important;
|
||||||
|
line-height: 1.4 !important;
|
||||||
|
font-weight: 400;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body p:first-child {
|
||||||
|
margin-top: 0 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .message-body p em {
|
.dark .message-body p em {
|
||||||
|
|
@ -95,3 +100,6 @@
|
||||||
margin-top: 8px;
|
margin-top: 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.message-body p, .chat .message-body ul, .chat .message-body ol {
|
||||||
|
margin-bottom: 10px !important;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,21 @@
|
||||||
line-height: 28px !important;
|
line-height: 28px !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .chat .message-body :is(p,li),
|
.dark .chat .message-body :is(p,li,h1,h2,h3,h4,h5,h6),
|
||||||
.dark .chat .message-body em:not(:is(h1,h2,h3,h4,h5,h6,b,strong) em),
|
.dark .chat .message-body em:not(:is(h1,h2,h3,h4,h5,h6,b,strong) em),
|
||||||
.dark .chat .message-body q:not(:is(h1,h2,h3,h4,h5,h6,b,strong) q) {
|
.dark .chat .message-body q:not(:is(h1,h2,h3,h4,h5,h6,b,strong) q) {
|
||||||
color: #d1d5db !important;
|
color: #d1d5db !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.chat .message-body :is(th, td),
|
||||||
|
.prose hr {
|
||||||
|
border-color: #40404096 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .chat .message-body :is(th, td),
|
||||||
|
.dark .prose hr {
|
||||||
|
border-color: rgb(255 255 255 / 30%) !important;
|
||||||
|
}
|
||||||
|
|
||||||
.chat .message-body :is(p, ul, ol) {
|
.chat .message-body :is(p, ul, ol) {
|
||||||
margin: 1.25em 0 !important;
|
margin: 1.25em 0 !important;
|
||||||
|
|
@ -69,7 +78,7 @@
|
||||||
|
|
||||||
.chat .user-message .text,
|
.chat .user-message .text,
|
||||||
.chat .assistant-message .text {
|
.chat .assistant-message .text {
|
||||||
max-width: 724px;
|
max-width: 700px;
|
||||||
margin-left: auto;
|
margin-left: auto;
|
||||||
margin-right: auto;
|
margin-right: auto;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
387
css/main.css
387
css/main.css
|
|
@ -2,8 +2,8 @@
|
||||||
--darker-gray: #1C1C1D;
|
--darker-gray: #1C1C1D;
|
||||||
--dark-gray: #212125;
|
--dark-gray: #212125;
|
||||||
--light-gray: #2C2E34;
|
--light-gray: #2C2E34;
|
||||||
--light-theme-gray: #f0f3fb;
|
--light-theme-gray: #f9fbff;
|
||||||
--border-color-dark: rgba(255, 255, 255, 0.15);
|
--border-color-dark: #525252;
|
||||||
--header-width: 112px;
|
--header-width: 112px;
|
||||||
--selected-item-color-dark: #282930;
|
--selected-item-color-dark: #282930;
|
||||||
}
|
}
|
||||||
|
|
@ -22,17 +22,6 @@
|
||||||
font-style: italic;
|
font-style: italic;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Hide spin buttons on number inputs (look bad on Windows) */
|
|
||||||
input[type="number"]::-webkit-outer-spin-button,
|
|
||||||
input[type="number"]::-webkit-inner-spin-button {
|
|
||||||
-webkit-appearance: none;
|
|
||||||
margin: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
input[type="number"] {
|
|
||||||
-moz-appearance: textfield;
|
|
||||||
}
|
|
||||||
|
|
||||||
.padded.svelte-12cmxck {
|
.padded.svelte-12cmxck {
|
||||||
padding: 3px 0;
|
padding: 3px 0;
|
||||||
}
|
}
|
||||||
|
|
@ -65,7 +54,7 @@ div.svelte-iyf88w {
|
||||||
height: 39.594px;
|
height: 39.594px;
|
||||||
align-self: end;
|
align-self: end;
|
||||||
line-height: 1em;
|
line-height: 1em;
|
||||||
border-radius: 0.75rem;
|
border-radius: 0.375rem;
|
||||||
flex: none;
|
flex: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -138,7 +127,7 @@ gradio-app > :first-child {
|
||||||
}
|
}
|
||||||
|
|
||||||
.header_bar {
|
.header_bar {
|
||||||
border-right: none;
|
border-right: var(--input-border-width) solid var(--input-border-color);
|
||||||
margin-bottom: 0;
|
margin-bottom: 0;
|
||||||
overflow-x: scroll;
|
overflow-x: scroll;
|
||||||
text-wrap: nowrap;
|
text-wrap: nowrap;
|
||||||
|
|
@ -161,7 +150,7 @@ gradio-app > :first-child {
|
||||||
|
|
||||||
.dark .header_bar {
|
.dark .header_bar {
|
||||||
border: none !important;
|
border: none !important;
|
||||||
box-shadow: none;
|
box-shadow: 0 3px 4px rgba(20 20 20 / 60%);
|
||||||
background-color: #8080802b;
|
background-color: #8080802b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -257,8 +246,8 @@ button {
|
||||||
|
|
||||||
.pretty_scrollbar::-webkit-scrollbar,
|
.pretty_scrollbar::-webkit-scrollbar,
|
||||||
#image-history-gallery > :nth-child(2)::-webkit-scrollbar {
|
#image-history-gallery > :nth-child(2)::-webkit-scrollbar {
|
||||||
width: 7px;
|
width: 8px;
|
||||||
height: 7px;
|
height: 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.pretty_scrollbar::-webkit-scrollbar-track,
|
.pretty_scrollbar::-webkit-scrollbar-track,
|
||||||
|
|
@ -271,7 +260,7 @@ button {
|
||||||
#image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb,
|
#image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb,
|
||||||
#image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb:hover {
|
#image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb:hover {
|
||||||
background: var(--neutral-300);
|
background: var(--neutral-300);
|
||||||
border-radius: 9999px;
|
border-radius: 30px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .pretty_scrollbar::-webkit-scrollbar-thumb,
|
.dark .pretty_scrollbar::-webkit-scrollbar-thumb,
|
||||||
|
|
@ -279,17 +268,18 @@ button {
|
||||||
.dark #image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb,
|
.dark #image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb,
|
||||||
.dark #image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb:hover {
|
.dark #image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb:hover {
|
||||||
background: rgb(255 255 255 / 6.25%);
|
background: rgb(255 255 255 / 6.25%);
|
||||||
border-radius: 9999px;
|
border-radius: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.pretty_scrollbar::-webkit-resizer,
|
.pretty_scrollbar::-webkit-resizer,
|
||||||
#image-history-gallery > :nth-child(2)::-webkit-resizer {
|
#image-history-gallery > :nth-child(2)::-webkit-resizer {
|
||||||
background: transparent;
|
background: #c5c5d2;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .pretty_scrollbar::-webkit-resizer,
|
.dark .pretty_scrollbar::-webkit-resizer,
|
||||||
.dark #image-history-gallery > :nth-child(2)::-webkit-resizer {
|
.dark #image-history-gallery > :nth-child(2)::-webkit-resizer {
|
||||||
background: transparent;
|
background: #ccc;
|
||||||
|
border-radius: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.pretty_scrollbar::-webkit-scrollbar-corner,
|
.pretty_scrollbar::-webkit-scrollbar-corner,
|
||||||
|
|
@ -410,6 +400,7 @@ audio {
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat .message {
|
.chat .message {
|
||||||
|
width: min(100%, 48rem);
|
||||||
margin-left: auto;
|
margin-left: auto;
|
||||||
margin-right: auto;
|
margin-right: auto;
|
||||||
text-align: start;
|
text-align: start;
|
||||||
|
|
@ -440,31 +431,12 @@ audio {
|
||||||
font-size: 16px;
|
font-size: 16px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .message-body h1,
|
.dark .message-body :is(h1, h2, h3, h4, h5, h6) {
|
||||||
.dark .message-body h2,
|
color: white !important;
|
||||||
.dark .message-body h3,
|
|
||||||
.dark .message-body h4,
|
|
||||||
.dark .message-body h5,
|
|
||||||
.dark .message-body h6 {
|
|
||||||
color: #e8e8e8 !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-body blockquote {
|
|
||||||
border-left-width: 4px;
|
|
||||||
border-left-color: var(--border-color-primary);
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-body h1,
|
|
||||||
.message-body h2,
|
|
||||||
.message-body h3,
|
|
||||||
.message-body h4,
|
|
||||||
.message-body h5,
|
|
||||||
.message-body h6 {
|
|
||||||
color: #1a1a1a;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body h1 {
|
.message-body h1 {
|
||||||
font-weight: 700;
|
font-weight: 800;
|
||||||
font-size: 2.25em;
|
font-size: 2.25em;
|
||||||
margin-top: 0;
|
margin-top: 0;
|
||||||
margin-bottom: 0.8888889em;
|
margin-bottom: 0.8888889em;
|
||||||
|
|
@ -496,13 +468,13 @@ audio {
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body h5 {
|
.message-body h5 {
|
||||||
font-weight: 600;
|
font-weight: normal;
|
||||||
font-size: 1em;
|
font-size: 1em;
|
||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body h6 {
|
.message-body h6 {
|
||||||
font-weight: 600;
|
font-weight: normal;
|
||||||
font-size: 1em;
|
font-size: 1em;
|
||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
|
|
@ -602,28 +574,10 @@ audio {
|
||||||
|
|
||||||
#chat-input textarea {
|
#chat-input textarea {
|
||||||
background: #f3f4f6;
|
background: #f3f4f6;
|
||||||
padding: 0.675rem 2.5rem 0.6rem;
|
padding: 0.65rem 2.5rem;
|
||||||
margin-top: 0.15rem;
|
border: 0;
|
||||||
border: 1px solid #d2d2d8;
|
box-shadow: 0;
|
||||||
border-radius: 1.5rem;
|
border-radius: 8px;
|
||||||
overflow-y: auto !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
#chat-input textarea::-webkit-scrollbar {
|
|
||||||
width: 7px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#chat-input textarea::-webkit-scrollbar-track {
|
|
||||||
background: transparent;
|
|
||||||
}
|
|
||||||
|
|
||||||
#chat-input textarea::-webkit-scrollbar-thumb {
|
|
||||||
background: var(--neutral-300);
|
|
||||||
border-radius: 9999px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.dark #chat-input textarea::-webkit-scrollbar-thumb {
|
|
||||||
background: rgb(255 255 255 / 6.25%);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#chat-input textarea::placeholder {
|
#chat-input textarea::placeholder {
|
||||||
|
|
@ -653,10 +607,6 @@ audio {
|
||||||
background: transparent;
|
background: transparent;
|
||||||
}
|
}
|
||||||
|
|
||||||
#chat-input .thumbnails {
|
|
||||||
padding-top: 3px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.chat-input-positioned {
|
.chat-input-positioned {
|
||||||
max-width: 54rem;
|
max-width: 54rem;
|
||||||
left: 50%;
|
left: 50%;
|
||||||
|
|
@ -759,30 +709,7 @@ audio {
|
||||||
|
|
||||||
.hover-element {
|
.hover-element {
|
||||||
position: relative;
|
position: relative;
|
||||||
padding-top: 4px;
|
font-size: 24px;
|
||||||
}
|
|
||||||
|
|
||||||
#hover-element-button {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
width: 32px;
|
|
||||||
height: 32px;
|
|
||||||
border-radius: 0.5rem;
|
|
||||||
cursor: pointer;
|
|
||||||
color: gray;
|
|
||||||
}
|
|
||||||
|
|
||||||
#hover-element-button:hover {
|
|
||||||
background-color: var(--background-fill-secondary);
|
|
||||||
}
|
|
||||||
|
|
||||||
#hover-element-button svg {
|
|
||||||
color: inherit;
|
|
||||||
}
|
|
||||||
|
|
||||||
.dark #hover-element-button:hover {
|
|
||||||
background-color: var(--selected-item-color-dark);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.hover-menu {
|
.hover-menu {
|
||||||
|
|
@ -790,40 +717,24 @@ audio {
|
||||||
position: absolute;
|
position: absolute;
|
||||||
bottom: 100%;
|
bottom: 100%;
|
||||||
left: 0;
|
left: 0;
|
||||||
background: white;
|
box-shadow: 0 0 5px rgb(0 0 0 / 25%);
|
||||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
|
||||||
box-shadow: 0 4px 16px rgb(0 0 0 / 12%), 0 1px 3px rgb(0 0 0 / 8%);
|
|
||||||
border-radius: 0.75rem;
|
|
||||||
z-index: 10000;
|
z-index: 10000;
|
||||||
min-width: 330px;
|
min-width: 330px;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
padding: 4px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.hover-menu::before {
|
|
||||||
content: '';
|
|
||||||
position: absolute;
|
|
||||||
top: 100%;
|
|
||||||
left: 0;
|
|
||||||
width: 100%;
|
|
||||||
height: 8px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.hover-menu > * {
|
|
||||||
border: none !important;
|
|
||||||
box-shadow: none !important;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.hover-menu button {
|
.hover-menu button {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
background: transparent !important;
|
background: white !important;
|
||||||
border: none !important;
|
border-radius: 0 !important;
|
||||||
border-radius: 0.5rem !important;
|
|
||||||
justify-content: space-between;
|
justify-content: space-between;
|
||||||
margin: 0 !important;
|
margin: 0 !important;
|
||||||
height: 36px;
|
height: 36px;
|
||||||
font-weight: 500;
|
border-color: transparent !important;
|
||||||
box-shadow: none !important;
|
}
|
||||||
|
|
||||||
|
.hover-menu button:not(#clear-history-confirm) {
|
||||||
|
border-bottom: 0 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.hover-menu button:hover {
|
.hover-menu button:hover {
|
||||||
|
|
@ -835,26 +746,19 @@ audio {
|
||||||
}
|
}
|
||||||
|
|
||||||
#show-controls {
|
#show-controls {
|
||||||
background-color: transparent;
|
background-color: white;
|
||||||
border: none !important;
|
border-color: transparent !important;
|
||||||
height: 36px;
|
height: 36px;
|
||||||
border-radius: 0.5rem;
|
border-radius: 0;
|
||||||
|
border-bottom: 0 !important;
|
||||||
padding-top: 3px;
|
padding-top: 3px;
|
||||||
padding-left: 4px;
|
padding-left: 4px;
|
||||||
display: flex;
|
display: flex;
|
||||||
font-weight: normal;
|
font-weight: normal;
|
||||||
}
|
}
|
||||||
|
|
||||||
#show-controls:hover {
|
|
||||||
background-color: #dbeafe;
|
|
||||||
}
|
|
||||||
|
|
||||||
.dark #show-controls {
|
.dark #show-controls {
|
||||||
background-color: transparent;
|
background-color: var(--darker-gray);
|
||||||
}
|
|
||||||
|
|
||||||
.dark #show-controls:hover {
|
|
||||||
background-color: var(--selected-item-color-dark);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#show-controls label {
|
#show-controls label {
|
||||||
|
|
@ -864,12 +768,12 @@ audio {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
padding-right: 12px;
|
padding-right: 12px;
|
||||||
gap: 10px;
|
gap: 10px;
|
||||||
font-weight: 500;
|
font-weight: 600;
|
||||||
color: var(--button-secondary-text-color);
|
color: var(--button-secondary-text-color);
|
||||||
}
|
}
|
||||||
|
|
||||||
#show-controls label input {
|
#show-controls label input {
|
||||||
margin-top: 5px;
|
margin-top: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.transparent-substring {
|
.transparent-substring {
|
||||||
|
|
@ -909,7 +813,7 @@ audio {
|
||||||
}
|
}
|
||||||
|
|
||||||
#chat-input-row {
|
#chat-input-row {
|
||||||
padding: 0.5rem 1rem 1rem;
|
padding: 1rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
#chat-col {
|
#chat-col {
|
||||||
|
|
@ -927,20 +831,9 @@ audio {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p, .message-body li {
|
.message-body ol, .message-body ul {
|
||||||
line-height: 1.75 !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-body p, .message-body ul, .message-body ol {
|
|
||||||
margin: 1.25em 0 !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-body :is(p, ul, ol):first-child {
|
|
||||||
margin-top: 0 !important;
|
margin-top: 0 !important;
|
||||||
}
|
margin-bottom: 1.25em !important;
|
||||||
|
|
||||||
.message-body :is(p, ul, ol):last-child {
|
|
||||||
margin-bottom: 0 !important;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ----------------------------------------------
|
/* ----------------------------------------------
|
||||||
|
|
@ -1002,7 +895,7 @@ audio {
|
||||||
.options {
|
.options {
|
||||||
z-index: 100 !important;
|
z-index: 100 !important;
|
||||||
border: 1px solid var(--input-border-color);
|
border: 1px solid var(--input-border-color);
|
||||||
border-radius: 0.5rem;
|
border-radius: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ----------------------------------------------
|
/* ----------------------------------------------
|
||||||
|
|
@ -1096,13 +989,9 @@ audio {
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
}
|
}
|
||||||
|
|
||||||
#past-chats label {
|
|
||||||
transition: background-color 0.15s ease;
|
|
||||||
}
|
|
||||||
|
|
||||||
#past-chats .selected,
|
#past-chats .selected,
|
||||||
#past-chats label:hover {
|
#past-chats label:hover {
|
||||||
background-color: #c8d8f5 !important;
|
background-color: #dbeafe !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
#past-chats-buttons,
|
#past-chats-buttons,
|
||||||
|
|
@ -1114,49 +1003,6 @@ audio {
|
||||||
padding-right: 0.5rem;
|
padding-right: 0.5rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
#new-chat-wrapper {
|
|
||||||
display: contents;
|
|
||||||
}
|
|
||||||
|
|
||||||
.new-chat-arrow {
|
|
||||||
cursor: pointer;
|
|
||||||
position: relative;
|
|
||||||
padding: 0;
|
|
||||||
margin-right: -15px;
|
|
||||||
height: 39.594px;
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
.new-chat-menu {
|
|
||||||
display: none;
|
|
||||||
position: absolute;
|
|
||||||
top: 0;
|
|
||||||
left: 0;
|
|
||||||
padding-top: 1.2em;
|
|
||||||
z-index: var(--layer-top);
|
|
||||||
white-space: nowrap;
|
|
||||||
}
|
|
||||||
|
|
||||||
.new-chat-arrow:hover .new-chat-menu {
|
|
||||||
display: block;
|
|
||||||
}
|
|
||||||
|
|
||||||
.new-chat-menu-item {
|
|
||||||
cursor: pointer;
|
|
||||||
padding: var(--size-2);
|
|
||||||
background: var(--background-fill-primary);
|
|
||||||
box-shadow: var(--shadow-drop-lg);
|
|
||||||
border-radius: var(--container-radius);
|
|
||||||
color: var(--body-text-color);
|
|
||||||
font-size: var(--text-md);
|
|
||||||
font-weight: var(--button-large-text-weight);
|
|
||||||
}
|
|
||||||
|
|
||||||
.new-chat-menu-item:hover {
|
|
||||||
background: var(--background-fill-secondary);
|
|
||||||
}
|
|
||||||
|
|
||||||
#past-chats-row,
|
#past-chats-row,
|
||||||
#chat-controls {
|
#chat-controls {
|
||||||
width: 260px;
|
width: 260px;
|
||||||
|
|
@ -1258,7 +1104,7 @@ audio {
|
||||||
Dark theme
|
Dark theme
|
||||||
---------------------------------------------- */
|
---------------------------------------------- */
|
||||||
.dark .header_bar {
|
.dark .header_bar {
|
||||||
background-color: #1a1a1a !important;
|
background-color: var(--darker-gray) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .header_bar button.selected {
|
.dark .header_bar button.selected {
|
||||||
|
|
@ -1268,28 +1114,22 @@ audio {
|
||||||
.dark #chat-input textarea {
|
.dark #chat-input textarea {
|
||||||
background: var(--light-gray);
|
background: var(--light-gray);
|
||||||
color: white !important;
|
color: white !important;
|
||||||
border-color: rgba(255, 255, 255, 0.06);
|
border-color: #292c3b;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark #chat-input textarea::placeholder {
|
.dark #chat-input textarea::placeholder {
|
||||||
color: #9ca3af;
|
color: #9ca3af;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .hover-menu {
|
|
||||||
background: var(--darker-gray);
|
|
||||||
border-color: transparent;
|
|
||||||
box-shadow: 0 4px 16px rgb(0 0 0 / 40%);
|
|
||||||
}
|
|
||||||
|
|
||||||
.dark .hover-menu button {
|
.dark .hover-menu button {
|
||||||
background-color: transparent !important;
|
border-color: var(--border-color-primary);
|
||||||
|
background-color: var(--darker-gray) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark #chat-controls,
|
.dark #chat-controls,
|
||||||
.dark #past-chats-row {
|
.dark #past-chats-row {
|
||||||
background-color: var(--darker-gray);
|
background-color: var(--darker-gray);
|
||||||
border: 0 !important;
|
border: 0 !important;
|
||||||
box-shadow: none;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark gradio-app .gradio-container.gradio-container-4-37-2 .contain #past-chats .selected,
|
.dark gradio-app .gradio-container.gradio-container-4-37-2 .contain #past-chats .selected,
|
||||||
|
|
@ -1326,11 +1166,11 @@ audio {
|
||||||
Light theme
|
Light theme
|
||||||
---------------------------------------------- */
|
---------------------------------------------- */
|
||||||
.header_bar {
|
.header_bar {
|
||||||
background-color: #e4e8f0 !important;
|
background-color: var(--light-theme-gray) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.header_bar button.selected {
|
.header_bar button.selected {
|
||||||
background: #c8d8f5;
|
background: #dbeafe;
|
||||||
}
|
}
|
||||||
|
|
||||||
#chat-controls,
|
#chat-controls,
|
||||||
|
|
@ -1339,11 +1179,11 @@ audio {
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark #chat-controls {
|
.dark #chat-controls {
|
||||||
border-left: 1px solid rgba(255, 255, 255, 0.06);
|
border-left: 1px solid #d9d9d0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark #past-chats-row {
|
.dark #past-chats-row {
|
||||||
border-right: 1px solid rgba(255, 255, 255, 0.06);
|
border-right: 1px solid #d9d9d0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#past-chats-toggle,
|
#past-chats-toggle,
|
||||||
|
|
@ -1444,7 +1284,8 @@ audio {
|
||||||
}
|
}
|
||||||
|
|
||||||
.footer-button svg {
|
.footer-button svg {
|
||||||
stroke: rgb(140 140 148);
|
stroke: rgb(156 163 175);
|
||||||
|
transition: stroke 0.2s;
|
||||||
}
|
}
|
||||||
|
|
||||||
.footer-button:hover svg {
|
.footer-button:hover svg {
|
||||||
|
|
@ -1459,12 +1300,11 @@ audio {
|
||||||
stroke: rgb(209 213 219);
|
stroke: rgb(209 213 219);
|
||||||
}
|
}
|
||||||
|
|
||||||
.block:has(> .label-wrap) {
|
.tgw-accordion {
|
||||||
padding: 10px 12px !important;
|
padding: 10px 12px !important;
|
||||||
border: 1px solid #d2d2d8;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .block:has(> .label-wrap) {
|
.dark .tgw-accordion {
|
||||||
border: 1px solid var(--border-color-dark);
|
border: 1px solid var(--border-color-dark);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1533,6 +1373,7 @@ audio {
|
||||||
overflow-wrap: break-word;
|
overflow-wrap: break-word;
|
||||||
max-height: 250px;
|
max-height: 250px;
|
||||||
overflow-y: scroll;
|
overflow-y: scroll;
|
||||||
|
contain: layout;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat .message-body .thinking-content p,
|
.chat .message-body .thinking-content p,
|
||||||
|
|
@ -1629,7 +1470,7 @@ strong {
|
||||||
min-height: 200px;
|
min-height: 200px;
|
||||||
max-height: 65vh;
|
max-height: 65vh;
|
||||||
padding: 10px;
|
padding: 10px;
|
||||||
border-radius: 0.5rem;
|
border-radius: 5px;
|
||||||
border: 1px solid #ccc;
|
border: 1px solid #ccc;
|
||||||
background-color: var(--light-theme-gray);
|
background-color: var(--light-theme-gray);
|
||||||
font-family: inherit;
|
font-family: inherit;
|
||||||
|
|
@ -1657,7 +1498,7 @@ strong {
|
||||||
.edit-control-button {
|
.edit-control-button {
|
||||||
padding: 6px 12px;
|
padding: 6px 12px;
|
||||||
border: 1px solid #ccc;
|
border: 1px solid #ccc;
|
||||||
border-radius: 0.75rem;
|
border-radius: 4px;
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
background-color: #f8f9fa;
|
background-color: #f8f9fa;
|
||||||
color: #212529;
|
color: #212529;
|
||||||
|
|
@ -1821,7 +1662,7 @@ button:focus {
|
||||||
.chat-parent {
|
.chat-parent {
|
||||||
/* Optimize for scrolling performance */
|
/* Optimize for scrolling performance */
|
||||||
will-change: scroll-position;
|
will-change: scroll-position;
|
||||||
contain: style paint;
|
contain: layout style paint;
|
||||||
|
|
||||||
/* Ensure GPU acceleration */
|
/* Ensure GPU acceleration */
|
||||||
transform: translateZ(0);
|
transform: translateZ(0);
|
||||||
|
|
@ -1840,7 +1681,7 @@ button:focus {
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .sidebar-vertical-separator {
|
.dark .sidebar-vertical-separator {
|
||||||
border-bottom: 1px solid rgba(255, 255, 255, 0.06);
|
border-bottom: 1px solid rgb(255 255 255 / 10%);
|
||||||
}
|
}
|
||||||
|
|
||||||
button#swap-height-width {
|
button#swap-height-width {
|
||||||
|
|
@ -1961,117 +1802,15 @@ table {
|
||||||
border-collapse: collapse;
|
border-collapse: collapse;
|
||||||
}
|
}
|
||||||
|
|
||||||
.table-wrapper {
|
|
||||||
overflow-x: auto;
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-body :is(td, th) {
|
|
||||||
word-break: normal;
|
|
||||||
overflow-wrap: normal;
|
|
||||||
}
|
|
||||||
|
|
||||||
table, tr, td, th, thead {
|
table, tr, td, th, thead {
|
||||||
border: 0;
|
border: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.prose hr {
|
|
||||||
border-color: var(--border-color-primary);
|
|
||||||
}
|
|
||||||
|
|
||||||
td + td,
|
td + td,
|
||||||
th + th {
|
th + th { border-left: 1px solid; }
|
||||||
border-left: 1px solid var(--border-color-primary) !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
tr + tr td,
|
tr + tr td,
|
||||||
tr + tr th {
|
tr + tr th { border-top: 1px solid; }
|
||||||
border-top: 1px solid var(--border-color-primary) !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
thead + tbody tr:first-child td,
|
thead + tbody tr:first-child td,
|
||||||
thead + tbody tr:first-child th {
|
thead + tbody tr:first-child th { border-top: 1px solid; }
|
||||||
border-top: 1px solid var(--border-color-primary) !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ------------------------------------------------
|
|
||||||
Tools CheckboxGroup - vertical DragDrop-like style
|
|
||||||
------------------------------------------------ */
|
|
||||||
|
|
||||||
/* "Refresh list" link in the Tools label */
|
|
||||||
.tools-refresh-link {
|
|
||||||
cursor: pointer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Checkbox list container */
|
|
||||||
#tools-group {
|
|
||||||
padding: 0 !important;
|
|
||||||
border-width: 0 !important;
|
|
||||||
background: transparent !important;
|
|
||||||
min-height: 0 !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
#tools-group .wrap {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
flex-wrap: nowrap;
|
|
||||||
gap: 4px;
|
|
||||||
padding: 0;
|
|
||||||
margin-top: var(--spacing-lg);
|
|
||||||
max-height: 350px;
|
|
||||||
overflow-y: auto;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Pretty scrollbar for the tools list */
|
|
||||||
#tools-group .wrap::-webkit-scrollbar {
|
|
||||||
width: 7px;
|
|
||||||
height: 7px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#tools-group .wrap::-webkit-scrollbar-track {
|
|
||||||
background: transparent;
|
|
||||||
}
|
|
||||||
|
|
||||||
#tools-group .wrap::-webkit-scrollbar-thumb,
|
|
||||||
#tools-group .wrap::-webkit-scrollbar-thumb:hover {
|
|
||||||
background: var(--neutral-300);
|
|
||||||
border-radius: 9999px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.dark #tools-group .wrap::-webkit-scrollbar-thumb,
|
|
||||||
.dark #tools-group .wrap::-webkit-scrollbar-thumb:hover {
|
|
||||||
background: rgb(255 255 255 / 6.25%);
|
|
||||||
border-radius: 9999px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#tools-group .wrap::-webkit-scrollbar-corner {
|
|
||||||
background: transparent;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Each checkbox item */
|
|
||||||
#tools-group label {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
gap: 8px;
|
|
||||||
padding: 5px 8px;
|
|
||||||
border-radius: var(--radius-sm, 4px);
|
|
||||||
background: var(--block-background-fill);
|
|
||||||
border: 1px solid var(--border-color-primary);
|
|
||||||
color: var(--body-text-color);
|
|
||||||
font-size: var(--input-text-size);
|
|
||||||
font-weight: var(--input-text-weight);
|
|
||||||
cursor: pointer;
|
|
||||||
user-select: none;
|
|
||||||
transition: border-color 0.15s ease, background 0.15s ease;
|
|
||||||
box-shadow: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
#tools-group label:hover {
|
|
||||||
border-color: var(--input-border-color-focus);
|
|
||||||
}
|
|
||||||
|
|
||||||
#tools-group label span {
|
|
||||||
flex: 1;
|
|
||||||
overflow: hidden;
|
|
||||||
text-overflow: ellipsis;
|
|
||||||
white-space: nowrap;
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ Used for talking to an instruction-following model using the prompt format defin
|
||||||
|
|
||||||
The prompt format is defined by the **Instruction template** parameter in "Parameters" > "Instruction template", which represents a Jinja2 template.
|
The prompt format is defined by the **Instruction template** parameter in "Parameters" > "Instruction template", which represents a Jinja2 template.
|
||||||
|
|
||||||
Note that when you load a model in the "Model" tab, the web UI will try to automatically detect its instruction template (if any) from the model metadata (e.g. `tokenizer_config.json` or GGUF metadata), and will update the values under "Parameters" > "Instruction template" accordingly. You should check the model card on Hugging Face to see if you are using the correct prompt format.
|
Note that when you load a model in the "Model" tab, the web UI will try to automatically detect its instruction template (if any), and will update the values under "Parameters" > "Instruction template" accordingly. This is done using a set of regular expressions defined in `user_data/models/config.yaml`. This detection is not guaranteed to be accurate. You should check the model card on Hugging Face to see if you are using the correct prompt format.
|
||||||
|
|
||||||
### Chat-instruct
|
### Chat-instruct
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,9 @@ Options:
|
||||||
* **cpu_memory**: Maximum CPU memory in GiB to use for CPU offloading via the accelerate library. Whatever doesn't fit in the GPU or CPU will go to a disk cache if the "disk" checkbox is enabled.
|
* **cpu_memory**: Maximum CPU memory in GiB to use for CPU offloading via the accelerate library. Whatever doesn't fit in the GPU or CPU will go to a disk cache if the "disk" checkbox is enabled.
|
||||||
* **compute_dtype**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
|
* **compute_dtype**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
|
||||||
* **quant_type**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
|
* **quant_type**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
|
||||||
|
* **alpha_value**: Used to extend the context length of a model with a minor loss in quality. I have measured 1.75 to be optimal for 1.5x context, and 2.5 for 2x context. That is, with alpha = 2.5 you can make a model with 4096 context length go to 8192 context length.
|
||||||
|
* **rope_freq_base**: Originally another way to write "alpha_value", it ended up becoming a necessary parameter for some models like CodeLlama, which was fine-tuned with this set to 1000000 and hence needs to be loaded with it set to 1000000 as well.
|
||||||
|
* **compress_pos_emb**: The first and original context-length extension method, discovered by [kaiokendev](https://kaiokendev.github.io/til). When set to 2, the context length is doubled, 3 and it's tripled, etc. It should only be used for models that have been fine-tuned with this parameter set to different than 1. For models that have not been tuned to have greater context length, alpha_value will lead to a smaller accuracy loss.
|
||||||
* **attn_implementation**: Choose the attention implementation. Valid options: `sdpa`, `eager`, `flash_attention_2`. The default (`sdpa`) works well in most cases; `flash_attention_2` may be useful for training.
|
* **attn_implementation**: Choose the attention implementation. Valid options: `sdpa`, `eager`, `flash_attention_2`. The default (`sdpa`) works well in most cases; `flash_attention_2` may be useful for training.
|
||||||
* **cpu**: Loads the model in CPU mode using Pytorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with transformers is older than llama.cpp and it works, but it's a lot slower. Note: this parameter has a different interpretation in the llama.cpp loader (see above).
|
* **cpu**: Loads the model in CPU mode using Pytorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with transformers is older than llama.cpp and it works, but it's a lot slower. Note: this parameter has a different interpretation in the llama.cpp loader (see above).
|
||||||
* **load_in_8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training and not inference, so load_in_8bit is slower than load_in_4bit (but more accurate).
|
* **load_in_8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training and not inference, so load_in_8bit is slower than load_in_4bit (but more accurate).
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ A LoRA is tied to a specific model architecture — a LoRA trained on Llama 3 8B
|
||||||
|
|
||||||
### Quick Start
|
### Quick Start
|
||||||
|
|
||||||
1. Load your base model with the **Transformers** loader (no LoRAs loaded).
|
1. Load your base model (no LoRAs loaded).
|
||||||
2. Open the **Training** tab > **Train LoRA**.
|
2. Open the **Training** tab > **Train LoRA**.
|
||||||
3. Pick a dataset and configure parameters (see [below](#parameters)).
|
3. Pick a dataset and configure parameters (see [below](#parameters)).
|
||||||
4. Click **Start LoRA Training** and monitor the [loss](#loss).
|
4. Click **Start LoRA Training** and monitor the [loss](#loss).
|
||||||
|
|
@ -100,8 +100,6 @@ Each parameter has a description in the UI. Below is guidance on the most import
|
||||||
|
|
||||||
VRAM usage during training is roughly similar to inference with ~1000 tokens of context. If you can run the model, you can probably train LoRAs with the default settings. If you run out of VRAM, reduce `Micro Batch Size` or `Cutoff Length`. Training 4-bit quantized models uses more VRAM — set `Micro Batch Size` to `1` to compensate.
|
VRAM usage during training is roughly similar to inference with ~1000 tokens of context. If you can run the model, you can probably train LoRAs with the default settings. If you run out of VRAM, reduce `Micro Batch Size` or `Cutoff Length`. Training 4-bit quantized models uses more VRAM — set `Micro Batch Size` to `1` to compensate.
|
||||||
|
|
||||||
**Gradient checkpointing** is enabled by default. It reduces VRAM usage by recomputing activations during the backward pass instead of storing them in memory. The tradeoff is ~20-30% slower training. There is no impact on accuracy — the results are mathematically identical. The savings are most noticeable with longer sequences and larger batch sizes. You can disable it if you have VRAM to spare and want faster training.
|
|
||||||
|
|
||||||
### Rank
|
### Rank
|
||||||
|
|
||||||
Higher rank = more learning capacity = larger adapter = more VRAM. Use 4–8 for style/format, 128–256 to teach factual knowledge.
|
Higher rank = more learning capacity = larger adapter = more VRAM. Use 4–8 for style/format, 128–256 to teach factual knowledge.
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ If you create an extension, you are welcome to host it in a GitHub repository an
|
||||||
|
|
||||||
|Extension|Description|
|
|Extension|Description|
|
||||||
|---------|-----------|
|
|---------|-----------|
|
||||||
|
|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. |
|
||||||
|[superboogav2](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superboogav2)| Enhanced RAG extension with support for PDF, DOCX, and PPTX files. |
|
|[superboogav2](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superboogav2)| Enhanced RAG extension with support for PDF, DOCX, and PPTX files. |
|
||||||
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|
||||||
|[coqui_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/coqui_tts)| Text-to-speech extension using Coqui XTTS v2. |
|
|[coqui_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/coqui_tts)| Text-to-speech extension using Coqui XTTS v2. |
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
## OpenAI/Anthropic-compatible API
|
## OpenAI compatible API
|
||||||
|
|
||||||
The main API for this project is meant to be a drop-in replacement for the OpenAI and Anthropic APIs, including Chat, Completions, and Messages endpoints.
|
The main API for this project is meant to be a drop-in replacement to the OpenAI API, including Chat and Completions endpoints.
|
||||||
|
|
||||||
* It is 100% offline and private.
|
* It is 100% offline and private.
|
||||||
* It doesn't create any logs.
|
* It doesn't create any logs.
|
||||||
|
|
@ -19,7 +19,7 @@ Add `--api` to your command-line flags.
|
||||||
|
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
For the documentation with all the endpoints, parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/modules/api/typing.py) file.
|
For the documentation with all the endpoints, parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file.
|
||||||
|
|
||||||
The official examples in the [OpenAI documentation](https://platform.openai.com/docs/api-reference) should also work, and the same parameters apply (although the API here has more optional parameters).
|
The official examples in the [OpenAI documentation](https://platform.openai.com/docs/api-reference) should also work, and the same parameters apply (although the API here has more optional parameters).
|
||||||
|
|
||||||
|
|
@ -39,7 +39,7 @@ curl http://127.0.0.1:5000/v1/completions \
|
||||||
|
|
||||||
#### Chat completions
|
#### Chat completions
|
||||||
|
|
||||||
Works best with instruction-following models. If the "instruction_template" variable is not provided, it will be detected automatically from the model metadata.
|
Works best with instruction-following models. If the "instruction_template" variable is not provided, it will be guessed automatically based on the model name using the regex patterns in `user_data/models/config.yaml`.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://127.0.0.1:5000/v1/chat/completions \
|
curl http://127.0.0.1:5000/v1/chat/completions \
|
||||||
|
|
@ -232,17 +232,6 @@ curl -k http://127.0.0.1:5000/v1/internal/model/load \
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also set a default instruction template for all subsequent API requests by passing `instruction_template` (a template name from `user_data/instruction-templates/`) or `instruction_template_str` (a raw Jinja2 string):
|
|
||||||
|
|
||||||
```shell
|
|
||||||
curl -k http://127.0.0.1:5000/v1/internal/model/load \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{
|
|
||||||
"model_name": "Qwen_Qwen3-0.6B-Q4_K_M.gguf",
|
|
||||||
"instruction_template": "Alpaca"
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Python chat example
|
#### Python chat example
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
@ -501,6 +490,16 @@ The following environment variables can be used (they take precedence over every
|
||||||
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | sentence-transformers/all-mpnet-base-v2 |
|
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | sentence-transformers/all-mpnet-base-v2 |
|
||||||
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
|
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
|
||||||
|
|
||||||
|
#### Persistent settings with `settings.yaml`
|
||||||
|
|
||||||
|
You can also set the following variables in your `settings.yaml` file:
|
||||||
|
|
||||||
|
```
|
||||||
|
openai-embedding_device: cuda
|
||||||
|
openai-embedding_model: "sentence-transformers/all-mpnet-base-v2"
|
||||||
|
openai-debug: 1
|
||||||
|
```
|
||||||
|
|
||||||
### Third-party application setup
|
### Third-party application setup
|
||||||
|
|
||||||
You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:
|
You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:
|
||||||
|
|
|
||||||
|
|
@ -1,172 +0,0 @@
|
||||||
## Tool calling in the UI
|
|
||||||
|
|
||||||
### 1. Load a model with tool-calling support
|
|
||||||
|
|
||||||
Load a model with tool-calling support from the Model tab.
|
|
||||||
|
|
||||||
### 2. Select tools
|
|
||||||
|
|
||||||
In the chat sidebar, check the tools you want the model to use:
|
|
||||||
|
|
||||||
- `web_search`: Search the web using DuckDuckGo.
|
|
||||||
- `fetch_webpage`: Fetch the content of a URL.
|
|
||||||
- `calculate`: Evaluate math expressions.
|
|
||||||
- `get_datetime`: Get the current date and time.
|
|
||||||
- `roll_dice`: Roll dice.
|
|
||||||
|
|
||||||
### 3. Chat
|
|
||||||
|
|
||||||
Send a message as usual. When the model decides it needs a tool, it will call it automatically. You will see each tool call and its result in a collapsible accordion inside the chat message.
|
|
||||||
|
|
||||||
The model may call multiple tools in sequence before giving its final answer.
|
|
||||||
|
|
||||||
## Writing custom tools
|
|
||||||
|
|
||||||
Each tool is a single `.py` file in `user_data/tools/`. It needs two things:
|
|
||||||
|
|
||||||
1. A `tool` dictionary that describes the function (name, description, parameters).
|
|
||||||
2. An `execute(arguments)` function that runs it and returns the result.
|
|
||||||
|
|
||||||
Here is a minimal example (`user_data/tools/get_datetime.py`):
|
|
||||||
|
|
||||||
```python
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
tool = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_datetime",
|
|
||||||
"description": "Get the current date and time.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def execute(arguments):
|
|
||||||
now = datetime.now()
|
|
||||||
return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")}
|
|
||||||
```
|
|
||||||
|
|
||||||
An example with parameters (`user_data/tools/roll_dice.py`):
|
|
||||||
|
|
||||||
```python
|
|
||||||
import random
|
|
||||||
|
|
||||||
tool = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "roll_dice",
|
|
||||||
"description": "Roll one or more dice with the specified number of sides.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"count": {"type": "integer", "description": "Number of dice to roll.", "default": 1},
|
|
||||||
"sides": {"type": "integer", "description": "Number of sides per die.", "default": 20},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def execute(arguments):
|
|
||||||
count = max(1, min(arguments.get("count", 1), 1000))
|
|
||||||
sides = max(2, min(arguments.get("sides", 20), 1000))
|
|
||||||
rolls = [random.randint(1, sides) for _ in range(count)]
|
|
||||||
return {"rolls": rolls, "total": sum(rolls)}
|
|
||||||
```
|
|
||||||
|
|
||||||
You can open the built-in tools in `user_data/tools/` for more examples.
|
|
||||||
|
|
||||||
## MCP servers
|
|
||||||
|
|
||||||
You can connect to remote [MCP (Model Context Protocol)](https://modelcontextprotocol.io/) servers to use their tools alongside local ones.
|
|
||||||
|
|
||||||
In the chat sidebar, open the **MCP servers** accordion and enter one server URL per line. For servers that require authentication, append headers after the URL separated by commas:
|
|
||||||
|
|
||||||
```
|
|
||||||
https://example.com/mcp
|
|
||||||
https://other.com/mcp,Authorization: Bearer sk-xxx
|
|
||||||
```
|
|
||||||
|
|
||||||
All tools from the configured servers are automatically discovered and made available to the model during generation. If an MCP tool has the same name as a selected local tool, the local tool takes priority.
|
|
||||||
|
|
||||||
## Tool calling over the API
|
|
||||||
|
|
||||||
Tool calling over the API follows the [OpenAI API](https://platform.openai.com/docs/guides/function-calling) convention. Define your tools, send them with your messages, and handle tool calls in a loop until the model gives a final answer.
|
|
||||||
|
|
||||||
```python
|
|
||||||
import json
|
|
||||||
import requests
|
|
||||||
|
|
||||||
url = "http://127.0.0.1:5000/v1/chat/completions"
|
|
||||||
|
|
||||||
tools = [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get the current weather for a given location.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"location": {"type": "string", "description": "City name"},
|
|
||||||
},
|
|
||||||
"required": ["location"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def execute_tool(name, arguments):
|
|
||||||
if name == "get_weather":
|
|
||||||
return {"temperature": "14°C", "condition": "partly cloudy"}
|
|
||||||
return {"error": f"Unknown tool: {name}"}
|
|
||||||
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "What's the weather like in Paris?"}]
|
|
||||||
|
|
||||||
for _ in range(10):
|
|
||||||
response = requests.post(url, json={"messages": messages, "tools": tools}).json()
|
|
||||||
choice = response["choices"][0]
|
|
||||||
|
|
||||||
if choice["finish_reason"] == "tool_calls":
|
|
||||||
messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": choice["message"]["content"],
|
|
||||||
"tool_calls": choice["message"]["tool_calls"],
|
|
||||||
})
|
|
||||||
|
|
||||||
for tool_call in choice["message"]["tool_calls"]:
|
|
||||||
name = tool_call["function"]["name"]
|
|
||||||
arguments = json.loads(tool_call["function"]["arguments"])
|
|
||||||
result = execute_tool(name, arguments)
|
|
||||||
print(f"Tool call: {name}({arguments}) => {result}")
|
|
||||||
|
|
||||||
messages.append({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tool_call["id"],
|
|
||||||
"content": json.dumps(result),
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
print(f"\nAssistant: {choice['message']['content']}")
|
|
||||||
break
|
|
||||||
```
|
|
||||||
|
|
||||||
## Supported models
|
|
||||||
|
|
||||||
The following models are supported:
|
|
||||||
|
|
||||||
- Qwen 3.5
|
|
||||||
- GPT-OSS
|
|
||||||
- Mistral Small / Devstral
|
|
||||||
- DeepSeek V3
|
|
||||||
- Kimi-K2
|
|
||||||
- MiniMax-M2.5
|
|
||||||
- GLM-5
|
|
||||||
- Llama 4
|
|
||||||
|
|
||||||
Other models that output tool calls as JSON (inside XML tags, code blocks, or plain JSON) are also supported through a generic fallback parser.
|
|
||||||
|
|
@ -158,21 +158,28 @@ class ModelDownloader:
|
||||||
# Also if GGUF and safetensors are available, download only safetensors
|
# Also if GGUF and safetensors are available, download only safetensors
|
||||||
if (has_pytorch or has_pt or has_gguf) and has_safetensors:
|
if (has_pytorch or has_pt or has_gguf) and has_safetensors:
|
||||||
has_gguf = False
|
has_gguf = False
|
||||||
keep = [i for i, c in enumerate(classifications) if c not in ['pytorch', 'pt', 'gguf']]
|
for i in range(len(classifications) - 1, -1, -1):
|
||||||
links = [links[i] for i in keep]
|
if classifications[i] in ['pytorch', 'pt', 'gguf']:
|
||||||
file_sizes = [file_sizes[i] for i in keep]
|
links.pop(i)
|
||||||
|
file_sizes.pop(i)
|
||||||
|
|
||||||
# For GGUF, try to download only the Q4_K_M if no specific file is specified.
|
# For GGUF, try to download only the Q4_K_M if no specific file is specified.
|
||||||
if has_gguf and specific_file is None:
|
if has_gguf and specific_file is None:
|
||||||
has_q4km = any('q4_k_m' in link.lower() for link in links)
|
has_q4km = False
|
||||||
|
for i in range(len(classifications) - 1, -1, -1):
|
||||||
|
if 'q4_k_m' in links[i].lower():
|
||||||
|
has_q4km = True
|
||||||
|
|
||||||
if has_q4km:
|
if has_q4km:
|
||||||
keep = [i for i, link in enumerate(links) if 'q4_k_m' in link.lower()]
|
for i in range(len(classifications) - 1, -1, -1):
|
||||||
|
if 'q4_k_m' not in links[i].lower():
|
||||||
|
links.pop(i)
|
||||||
|
file_sizes.pop(i)
|
||||||
else:
|
else:
|
||||||
keep = [i for i, link in enumerate(links) if not link.lower().endswith('.gguf')]
|
for i in range(len(classifications) - 1, -1, -1):
|
||||||
|
if links[i].lower().endswith('.gguf'):
|
||||||
links = [links[i] for i in keep]
|
links.pop(i)
|
||||||
file_sizes = [file_sizes[i] for i in keep]
|
file_sizes.pop(i)
|
||||||
|
|
||||||
is_llamacpp = has_gguf and specific_file is not None
|
is_llamacpp = has_gguf and specific_file is not None
|
||||||
return links, sha256, is_lora, is_llamacpp, file_sizes
|
return links, sha256, is_lora, is_llamacpp, file_sizes
|
||||||
|
|
|
||||||
635
extensions/openai/completions.py
Normal file
635
extensions/openai/completions.py
Normal file
|
|
@ -0,0 +1,635 @@
|
||||||
|
import copy
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
import yaml
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from extensions.openai.errors import InvalidRequestError
|
||||||
|
from extensions.openai.typing import ToolDefinition
|
||||||
|
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
|
||||||
|
from modules import shared
|
||||||
|
from modules.chat import (
|
||||||
|
generate_chat_prompt,
|
||||||
|
generate_chat_reply,
|
||||||
|
load_character_memoized,
|
||||||
|
load_instruction_template_memoized
|
||||||
|
)
|
||||||
|
from modules.image_utils import convert_openai_messages_to_images
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
from modules.presets import load_preset_memoized
|
||||||
|
from modules.text_generation import decode, encode, generate_reply
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def load_chat_template_file(filepath):
|
||||||
|
"""Load a chat template from a file path (.jinja, .jinja2, or .yaml/.yml)."""
|
||||||
|
filepath = Path(filepath)
|
||||||
|
ext = filepath.suffix.lower()
|
||||||
|
text = filepath.read_text(encoding='utf-8')
|
||||||
|
if ext in ['.yaml', '.yml']:
|
||||||
|
data = yaml.safe_load(text)
|
||||||
|
return data.get('instruction_template', '')
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def convert_logprobs_to_tiktoken(model, logprobs):
|
||||||
|
# more problems than it's worth.
|
||||||
|
# try:
|
||||||
|
# encoder = tiktoken.encoding_for_model(model)
|
||||||
|
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||||
|
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
||||||
|
# except KeyError:
|
||||||
|
# # assume native tokens if we can't find the tokenizer
|
||||||
|
# return logprobs
|
||||||
|
|
||||||
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
|
def process_parameters(body, is_legacy=False):
|
||||||
|
generate_params = body
|
||||||
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
generate_params['max_new_tokens'] = body.pop(max_tokens_str)
|
||||||
|
if generate_params['truncation_length'] == 0:
|
||||||
|
generate_params['truncation_length'] = shared.settings['truncation_length']
|
||||||
|
|
||||||
|
if generate_params['temperature'] == 0:
|
||||||
|
generate_params['do_sample'] = False
|
||||||
|
generate_params['top_k'] = 1
|
||||||
|
|
||||||
|
if body['preset'] is not None:
|
||||||
|
preset = load_preset_memoized(body['preset'])
|
||||||
|
generate_params.update(preset)
|
||||||
|
|
||||||
|
generate_params['custom_stopping_strings'] = []
|
||||||
|
if 'stop' in body: # str or array, max len 4 (ignored)
|
||||||
|
if isinstance(body['stop'], str):
|
||||||
|
generate_params['custom_stopping_strings'] = [body['stop']]
|
||||||
|
elif isinstance(body['stop'], list):
|
||||||
|
generate_params['custom_stopping_strings'] = body['stop']
|
||||||
|
|
||||||
|
if shared.args.loader != 'llama.cpp':
|
||||||
|
from transformers import LogitsProcessorList
|
||||||
|
|
||||||
|
from modules.transformers_loader import (
|
||||||
|
LogitsBiasProcessor,
|
||||||
|
LogprobProcessor
|
||||||
|
)
|
||||||
|
|
||||||
|
logits_processor = []
|
||||||
|
logit_bias = body.get('logit_bias', None)
|
||||||
|
if logit_bias: # {str: float, ...}
|
||||||
|
logits_processor = [LogitsBiasProcessor(logit_bias)]
|
||||||
|
|
||||||
|
logprobs = None # coming to chat eventually
|
||||||
|
if 'logprobs' in body:
|
||||||
|
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||||
|
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||||
|
logits_processor.extend([generate_params['logprob_proc']])
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
|
if logits_processor: # requires logits_processor support
|
||||||
|
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||||
|
|
||||||
|
return generate_params
|
||||||
|
|
||||||
|
|
||||||
|
def process_multimodal_content(content):
|
||||||
|
"""Extract text and add image placeholders from OpenAI multimodal format"""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts = []
|
||||||
|
image_placeholders = ""
|
||||||
|
for item in content:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
item_type = item.get('type', '')
|
||||||
|
if item_type == 'text':
|
||||||
|
text_parts.append(item.get('text', ''))
|
||||||
|
elif item_type == 'image_url':
|
||||||
|
image_placeholders += "<__media__>"
|
||||||
|
|
||||||
|
final_text = ' '.join(text_parts)
|
||||||
|
if image_placeholders:
|
||||||
|
return f"{image_placeholders}\n\n{final_text}"
|
||||||
|
else:
|
||||||
|
return final_text
|
||||||
|
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_history(history):
|
||||||
|
'''
|
||||||
|
Chat histories in this program are in the format [message, reply].
|
||||||
|
This function converts OpenAI histories to that format.
|
||||||
|
'''
|
||||||
|
chat_dialogue = []
|
||||||
|
current_message = ""
|
||||||
|
current_reply = ""
|
||||||
|
user_input = ""
|
||||||
|
user_input_last = True
|
||||||
|
system_message = ""
|
||||||
|
|
||||||
|
for entry in history:
|
||||||
|
content = entry["content"]
|
||||||
|
role = entry["role"]
|
||||||
|
|
||||||
|
if role == "user":
|
||||||
|
# Extract text content (images handled by model-specific code)
|
||||||
|
content = process_multimodal_content(content)
|
||||||
|
user_input = content
|
||||||
|
user_input_last = True
|
||||||
|
|
||||||
|
if current_message:
|
||||||
|
chat_dialogue.append([current_message, '', '', {}])
|
||||||
|
current_message = ""
|
||||||
|
|
||||||
|
current_message = content
|
||||||
|
elif role == "assistant":
|
||||||
|
meta = {}
|
||||||
|
tool_calls = entry.get("tool_calls")
|
||||||
|
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
|
||||||
|
meta["tool_calls"] = tool_calls
|
||||||
|
if content.strip() == "":
|
||||||
|
content = "" # keep empty content, don't skip
|
||||||
|
|
||||||
|
current_reply = content
|
||||||
|
user_input_last = False
|
||||||
|
if current_message:
|
||||||
|
chat_dialogue.append([current_message, current_reply, '', meta])
|
||||||
|
current_message = ""
|
||||||
|
current_reply = ""
|
||||||
|
else:
|
||||||
|
chat_dialogue.append(['', current_reply, '', meta])
|
||||||
|
elif role == "tool":
|
||||||
|
user_input_last = False
|
||||||
|
meta = {}
|
||||||
|
if "tool_call_id" in entry:
|
||||||
|
meta["tool_call_id"] = entry["tool_call_id"]
|
||||||
|
chat_dialogue.append(['', '', content, meta])
|
||||||
|
elif role == "system":
|
||||||
|
system_message += f"\n{content}" if system_message else content
|
||||||
|
|
||||||
|
if not user_input_last:
|
||||||
|
user_input = ""
|
||||||
|
|
||||||
|
return user_input, system_message, {
|
||||||
|
'internal': chat_dialogue,
|
||||||
|
'visible': copy.deepcopy(chat_dialogue),
|
||||||
|
'messages': history # Store original messages for multimodal models
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False, stop_event=None) -> dict:
|
||||||
|
if body.get('functions', []):
|
||||||
|
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
||||||
|
|
||||||
|
if body.get('function_call', ''):
|
||||||
|
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
||||||
|
|
||||||
|
if 'messages' not in body:
|
||||||
|
raise InvalidRequestError(message="messages is required", param='messages')
|
||||||
|
|
||||||
|
tools = None
|
||||||
|
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
|
||||||
|
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
|
||||||
|
|
||||||
|
messages = body['messages']
|
||||||
|
for m in messages:
|
||||||
|
if 'role' not in m:
|
||||||
|
raise InvalidRequestError(message="messages: missing role", param='messages')
|
||||||
|
elif m['role'] == 'function':
|
||||||
|
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||||
|
|
||||||
|
# Handle multimodal content validation
|
||||||
|
content = m.get('content')
|
||||||
|
if content is None:
|
||||||
|
# OpenAI allows content: null on assistant messages when tool_calls is present
|
||||||
|
if m['role'] == 'assistant' and m.get('tool_calls'):
|
||||||
|
m['content'] = ''
|
||||||
|
else:
|
||||||
|
raise InvalidRequestError(message="messages: missing content", param='messages')
|
||||||
|
|
||||||
|
# Validate multimodal content structure
|
||||||
|
if isinstance(content, list):
|
||||||
|
for item in content:
|
||||||
|
if not isinstance(item, dict) or 'type' not in item:
|
||||||
|
raise InvalidRequestError(message="messages: invalid content item format", param='messages')
|
||||||
|
if item['type'] not in ['text', 'image_url']:
|
||||||
|
raise InvalidRequestError(message="messages: unsupported content type", param='messages')
|
||||||
|
if item['type'] == 'text' and 'text' not in item:
|
||||||
|
raise InvalidRequestError(message="messages: missing text in content item", param='messages')
|
||||||
|
if item['type'] == 'image_url' and ('image_url' not in item or 'url' not in item['image_url']):
|
||||||
|
raise InvalidRequestError(message="messages: missing image_url in content item", param='messages')
|
||||||
|
|
||||||
|
# Chat Completions
|
||||||
|
object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
# generation parameters
|
||||||
|
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||||
|
if stop_event is not None:
|
||||||
|
generate_params['stop_event'] = stop_event
|
||||||
|
continue_ = body['continue_']
|
||||||
|
|
||||||
|
# Instruction template
|
||||||
|
if body['instruction_template_str']:
|
||||||
|
instruction_template_str = body['instruction_template_str']
|
||||||
|
elif body['instruction_template']:
|
||||||
|
instruction_template = body['instruction_template']
|
||||||
|
instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
|
||||||
|
instruction_template_str = load_instruction_template_memoized(instruction_template)
|
||||||
|
elif shared.args.chat_template_file:
|
||||||
|
instruction_template_str = load_chat_template_file(shared.args.chat_template_file)
|
||||||
|
else:
|
||||||
|
instruction_template_str = shared.settings['instruction_template_str']
|
||||||
|
|
||||||
|
chat_template_str = body['chat_template_str'] or shared.default_settings['chat_template_str']
|
||||||
|
chat_instruct_command = body['chat_instruct_command'] or shared.default_settings['chat-instruct_command']
|
||||||
|
|
||||||
|
# Chat character
|
||||||
|
character = body['character'] or shared.default_settings['character']
|
||||||
|
character = "Assistant" if character == "None" else character
|
||||||
|
name1 = body['user_name'] or shared.default_settings['name1']
|
||||||
|
name1, name2, _, greeting, context = load_character_memoized(character, name1, '')
|
||||||
|
name2 = body['bot_name'] or name2
|
||||||
|
context = body['context'] or context
|
||||||
|
greeting = body['greeting'] or greeting
|
||||||
|
user_bio = body['user_bio'] or ''
|
||||||
|
|
||||||
|
# History
|
||||||
|
user_input, custom_system_message, history = convert_history(messages)
|
||||||
|
|
||||||
|
generate_params.update({
|
||||||
|
'mode': body['mode'],
|
||||||
|
'name1': name1,
|
||||||
|
'name2': name2,
|
||||||
|
'context': context,
|
||||||
|
'greeting': greeting,
|
||||||
|
'user_bio': user_bio,
|
||||||
|
'instruction_template_str': instruction_template_str,
|
||||||
|
'custom_system_message': custom_system_message,
|
||||||
|
'chat_template_str': chat_template_str,
|
||||||
|
'chat-instruct_command': chat_instruct_command,
|
||||||
|
'tools': tools,
|
||||||
|
'history': history,
|
||||||
|
'stream': stream
|
||||||
|
})
|
||||||
|
|
||||||
|
max_tokens = generate_params['max_new_tokens']
|
||||||
|
if max_tokens in [None, 0]:
|
||||||
|
generate_params['max_new_tokens'] = 512
|
||||||
|
generate_params['auto_max_new_tokens'] = True
|
||||||
|
|
||||||
|
requested_model = generate_params.pop('model')
|
||||||
|
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||||
|
|
||||||
|
def chat_streaming_chunk(content, chunk_tool_calls=None):
|
||||||
|
# begin streaming
|
||||||
|
chunk = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": None,
|
||||||
|
"delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls},
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
if logprob_proc: # not official for chat yet
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
# else:
|
||||||
|
# chunk[resp_list][0]["logprobs"] = None
|
||||||
|
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
if prompt_only:
|
||||||
|
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
|
||||||
|
yield {'prompt': prompt}
|
||||||
|
return
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
yield chat_streaming_chunk('')
|
||||||
|
|
||||||
|
generator = generate_chat_reply(
|
||||||
|
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
seen_content = ''
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
end_last_tool_call = 0
|
||||||
|
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a['internal'][-1][1]
|
||||||
|
|
||||||
|
if supported_tools is not None:
|
||||||
|
tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
|
||||||
|
if len(tool_call) > 0:
|
||||||
|
for tc in tool_call:
|
||||||
|
tc["id"] = getToolCallId()
|
||||||
|
tc["index"] = len(tool_calls)
|
||||||
|
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
|
||||||
|
tool_calls.append(tc)
|
||||||
|
end_last_tool_call = len(answer)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
len_seen = len(seen_content)
|
||||||
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
|
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = chat_streaming_chunk(new_content)
|
||||||
|
|
||||||
|
seen_content = answer
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# stop generation if tool_calls were generated previously
|
||||||
|
if len(tool_calls) > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
stop_reason = "stop"
|
||||||
|
if len(tool_calls) > 0:
|
||||||
|
stop_reason = "tool_calls"
|
||||||
|
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
chunk = chat_streaming_chunk('', tool_calls)
|
||||||
|
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||||
|
chunk['usage'] = {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
resp = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": stop_reason,
|
||||||
|
"message": {"role": "assistant", "content": answer, **({"tool_calls": tool_calls} if tool_calls else {})},
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if logprob_proc: # not official for chat yet
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
# else:
|
||||||
|
# resp[resp_list][0]["logprobs"] = None
|
||||||
|
|
||||||
|
yield resp
|
||||||
|
|
||||||
|
|
||||||
|
def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
|
||||||
|
object_type = 'text_completion'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
prompt_str = 'context' if is_legacy else 'prompt'
|
||||||
|
|
||||||
|
# Handle both prompt and messages format for unified multimodal support
|
||||||
|
if prompt_str not in body or body[prompt_str] is None:
|
||||||
|
if 'messages' in body:
|
||||||
|
# Convert messages format to prompt for completions endpoint
|
||||||
|
prompt_text = ""
|
||||||
|
for message in body.get('messages', []):
|
||||||
|
if isinstance(message, dict) and 'content' in message:
|
||||||
|
# Extract text content from multimodal messages
|
||||||
|
content = message['content']
|
||||||
|
if isinstance(content, str):
|
||||||
|
prompt_text += content
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict) and item.get('type') == 'text':
|
||||||
|
prompt_text += item.get('text', '')
|
||||||
|
|
||||||
|
# Allow empty prompts for image-only requests
|
||||||
|
body[prompt_str] = prompt_text
|
||||||
|
else:
|
||||||
|
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||||
|
|
||||||
|
# common params
|
||||||
|
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||||
|
max_tokens = generate_params['max_new_tokens']
|
||||||
|
generate_params['stream'] = stream
|
||||||
|
if stop_event is not None:
|
||||||
|
generate_params['stop_event'] = stop_event
|
||||||
|
requested_model = generate_params.pop('model')
|
||||||
|
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||||
|
suffix = body['suffix'] if body['suffix'] else ''
|
||||||
|
echo = body['echo']
|
||||||
|
|
||||||
|
# Add messages to generate_params if present for multimodal processing
|
||||||
|
if body.get('messages'):
|
||||||
|
generate_params['messages'] = body['messages']
|
||||||
|
raw_images = convert_openai_messages_to_images(generate_params['messages'])
|
||||||
|
if raw_images:
|
||||||
|
logger.info(f"Found {len(raw_images)} image(s) in request.")
|
||||||
|
generate_params['raw_images'] = raw_images
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
prompt_arg = body[prompt_str]
|
||||||
|
|
||||||
|
# Handle empty/None prompts (e.g., image-only requests)
|
||||||
|
if prompt_arg is None:
|
||||||
|
prompt_arg = ""
|
||||||
|
|
||||||
|
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and len(prompt_arg) > 0 and isinstance(prompt_arg[0], int)):
|
||||||
|
prompt_arg = [prompt_arg]
|
||||||
|
|
||||||
|
resp_list_data = []
|
||||||
|
total_completion_token_count = 0
|
||||||
|
total_prompt_token_count = 0
|
||||||
|
|
||||||
|
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||||
|
if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
|
||||||
|
# token lists
|
||||||
|
if requested_model == shared.model_name:
|
||||||
|
prompt = decode(prompt)[0]
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(requested_model)
|
||||||
|
prompt = encoder.decode(prompt)
|
||||||
|
except KeyError:
|
||||||
|
prompt = decode(prompt)[0]
|
||||||
|
|
||||||
|
prefix = prompt if echo else ''
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||||
|
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||||
|
answer = ''
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
total_prompt_token_count += token_count
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
total_completion_token_count += completion_token_count
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
respi = {
|
||||||
|
"index": idx,
|
||||||
|
"finish_reason": stop_reason,
|
||||||
|
"text": prefix + answer + suffix,
|
||||||
|
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp_list_data.extend([respi])
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: resp_list_data,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": total_prompt_token_count,
|
||||||
|
"completion_tokens": total_completion_token_count,
|
||||||
|
"total_tokens": total_prompt_token_count + total_completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
yield resp
|
||||||
|
else:
|
||||||
|
prompt = body[prompt_str]
|
||||||
|
if isinstance(prompt, list):
|
||||||
|
if prompt and isinstance(prompt[0], int):
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(requested_model)
|
||||||
|
prompt = encoder.decode(prompt)
|
||||||
|
except KeyError:
|
||||||
|
prompt = decode(prompt)[0]
|
||||||
|
else:
|
||||||
|
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||||
|
|
||||||
|
prefix = prompt if echo else ''
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
|
def text_streaming_chunk(content):
|
||||||
|
# begin streaming
|
||||||
|
chunk = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": None,
|
||||||
|
"text": content,
|
||||||
|
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
yield text_streaming_chunk(prefix)
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||||
|
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||||
|
answer = ''
|
||||||
|
seen_content = ''
|
||||||
|
completion_token_count = 0
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
len_seen = len(seen_content)
|
||||||
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
|
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen_content = answer
|
||||||
|
chunk = text_streaming_chunk(new_content)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
chunk = text_streaming_chunk(suffix)
|
||||||
|
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||||
|
chunk["usage"] = {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
|
||||||
|
generator = chat_completions_common(body, is_legacy, stream=False, stop_event=stop_event)
|
||||||
|
return deque(generator, maxlen=1).pop()
|
||||||
|
|
||||||
|
|
||||||
|
def stream_chat_completions(body: dict, is_legacy: bool = False, stop_event=None):
|
||||||
|
for resp in chat_completions_common(body, is_legacy, stream=True, stop_event=stop_event):
|
||||||
|
yield resp
|
||||||
|
|
||||||
|
|
||||||
|
def completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
|
||||||
|
generator = completions_common(body, is_legacy, stream=False, stop_event=stop_event)
|
||||||
|
return deque(generator, maxlen=1).pop()
|
||||||
|
|
||||||
|
|
||||||
|
def stream_completions(body: dict, is_legacy: bool = False, stop_event=None):
|
||||||
|
for resp in completions_common(body, is_legacy, stream=True, stop_event=stop_event):
|
||||||
|
yield resp
|
||||||
|
|
||||||
|
|
||||||
|
def validateTools(tools: list[dict]):
|
||||||
|
# Validate each tool definition in the JSON array
|
||||||
|
valid_tools = None
|
||||||
|
for idx in range(len(tools)):
|
||||||
|
tool = tools[idx]
|
||||||
|
try:
|
||||||
|
tool_definition = ToolDefinition(**tool)
|
||||||
|
# Backfill defaults so Jinja2 templates don't crash on missing fields
|
||||||
|
func = tool.get("function", {})
|
||||||
|
if "description" not in func:
|
||||||
|
func["description"] = ""
|
||||||
|
if "parameters" not in func:
|
||||||
|
func["parameters"] = {"type": "object", "properties": {}}
|
||||||
|
if valid_tools is None:
|
||||||
|
valid_tools = []
|
||||||
|
valid_tools.append(tool)
|
||||||
|
except ValidationError:
|
||||||
|
raise InvalidRequestError(message=f"Invalid tool specification at index {idx}.", param='tools')
|
||||||
|
|
||||||
|
return valid_tools
|
||||||
|
|
@ -3,10 +3,9 @@ import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
|
|
||||||
from .errors import ServiceUnavailableError
|
from extensions.openai.errors import ServiceUnavailableError
|
||||||
from .utils import debug_msg, float_list_to_base64
|
from extensions.openai.utils import debug_msg, float_list_to_base64
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
embeddings_params_initialized = False
|
embeddings_params_initialized = False
|
||||||
|
|
||||||
|
|
@ -18,12 +17,14 @@ def initialize_embedding_params():
|
||||||
'''
|
'''
|
||||||
global embeddings_params_initialized
|
global embeddings_params_initialized
|
||||||
if not embeddings_params_initialized:
|
if not embeddings_params_initialized:
|
||||||
|
from extensions.openai.script import params
|
||||||
|
|
||||||
global st_model, embeddings_model, embeddings_device
|
global st_model, embeddings_model, embeddings_device
|
||||||
|
|
||||||
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", 'sentence-transformers/all-mpnet-base-v2')
|
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
|
||||||
embeddings_model = None
|
embeddings_model = None
|
||||||
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
|
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
|
||||||
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", 'cpu')
|
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
|
||||||
if embeddings_device.lower() == 'auto':
|
if embeddings_device.lower() == 'auto':
|
||||||
embeddings_device = None
|
embeddings_device = None
|
||||||
|
|
||||||
|
|
@ -40,14 +41,14 @@ def load_embedding_model(model: str):
|
||||||
initialize_embedding_params()
|
initialize_embedding_params()
|
||||||
global embeddings_device, embeddings_model
|
global embeddings_device, embeddings_model
|
||||||
try:
|
try:
|
||||||
logger.info(f"Try embedding model: {model} on {embeddings_device}")
|
print(f"Try embedding model: {model} on {embeddings_device}")
|
||||||
if 'jina-embeddings' in model:
|
if 'jina-embeddings' in model:
|
||||||
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=shared.args.trust_remote_code)
|
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True) # trust_remote_code is needed to use the encode method
|
||||||
embeddings_model = embeddings_model.to(embeddings_device)
|
embeddings_model = embeddings_model.to(embeddings_device)
|
||||||
else:
|
else:
|
||||||
embeddings_model = SentenceTransformer(model, device=embeddings_device)
|
embeddings_model = SentenceTransformer(model, device=embeddings_device)
|
||||||
|
|
||||||
logger.info(f"Loaded embedding model: {model}")
|
print(f"Loaded embedding model: {model}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
embeddings_model = None
|
embeddings_model = None
|
||||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
||||||
|
|
@ -4,12 +4,9 @@ OpenAI-compatible image generation using local diffusion models.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from extensions.openai.errors import ServiceUnavailableError
|
||||||
|
|
||||||
from .errors import ServiceUnavailableError
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,7 +15,7 @@ def generations(request):
|
||||||
Generate images using the loaded diffusion model.
|
Generate images using the loaded diffusion model.
|
||||||
Returns dict with 'created' timestamp and 'data' list of images.
|
Returns dict with 'created' timestamp and 'data' list of images.
|
||||||
"""
|
"""
|
||||||
from modules.ui_image_generation import build_generation_metadata, generate
|
from modules.ui_image_generation import generate
|
||||||
|
|
||||||
if shared.image_model is None:
|
if shared.image_model is None:
|
||||||
raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
|
raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
|
||||||
|
|
@ -49,18 +46,10 @@ def generations(request):
|
||||||
if not images:
|
if not images:
|
||||||
raise ServiceUnavailableError("Image generation failed or produced no images.")
|
raise ServiceUnavailableError("Image generation failed or produced no images.")
|
||||||
|
|
||||||
# Build response with per-batch metadata (seed increments per batch)
|
# Build response
|
||||||
base_seed = state.get('image_seed_resolved', state['image_seed'])
|
|
||||||
batch_size = int(state['image_batch_size'])
|
|
||||||
|
|
||||||
resp = {'created': int(time.time()), 'data': []}
|
resp = {'created': int(time.time()), 'data': []}
|
||||||
for idx, img in enumerate(images):
|
for img in images:
|
||||||
batch_seed = base_seed + idx // batch_size
|
b64 = _image_to_base64(img)
|
||||||
metadata = build_generation_metadata(state, batch_seed)
|
|
||||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
|
||||||
png_info = PngInfo()
|
|
||||||
png_info.add_text("image_gen_settings", metadata_json)
|
|
||||||
b64 = _image_to_base64(img, png_info)
|
|
||||||
|
|
||||||
image_obj = {'revised_prompt': request.prompt}
|
image_obj = {'revised_prompt': request.prompt}
|
||||||
|
|
||||||
|
|
@ -74,7 +63,7 @@ def generations(request):
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
def _image_to_base64(image, png_info=None) -> str:
|
def _image_to_base64(image) -> str:
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
image.save(buffered, format="PNG", pnginfo=png_info)
|
image.save(buffered, format="PNG")
|
||||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from .completions import process_parameters
|
from extensions.openai.completions import process_parameters
|
||||||
from modules.logits import get_next_logits
|
from modules.logits import get_next_logits
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from modules import loaders, shared
|
from modules import shared, ui
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, unload_model
|
from modules.models import load_model, unload_model
|
||||||
from modules.models_settings import get_model_metadata, load_instruction_template, update_model_parameters
|
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||||
from modules.utils import get_available_loras, get_available_models
|
from modules.utils import get_available_loras, get_available_models
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -20,14 +20,10 @@ def list_models():
|
||||||
|
|
||||||
def list_models_openai_format():
|
def list_models_openai_format():
|
||||||
"""Returns model list in OpenAI API format"""
|
"""Returns model list in OpenAI API format"""
|
||||||
if shared.model_name and shared.model_name != 'None':
|
model_names = get_available_models()
|
||||||
data = [model_info_dict(shared.model_name)]
|
|
||||||
else:
|
|
||||||
data = []
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": data
|
"data": [model_info_dict(name) for name in model_names]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -42,27 +38,19 @@ def model_info_dict(model_name: str) -> dict:
|
||||||
|
|
||||||
def _load_model(data):
|
def _load_model(data):
|
||||||
model_name = data["model_name"]
|
model_name = data["model_name"]
|
||||||
args = data.get("args")
|
args = data["args"]
|
||||||
|
settings = data["settings"]
|
||||||
|
|
||||||
unload_model()
|
unload_model()
|
||||||
model_settings = get_model_metadata(model_name)
|
model_settings = get_model_metadata(model_name)
|
||||||
|
update_model_parameters(model_settings)
|
||||||
|
|
||||||
# Update shared.args with custom model loading settings
|
# Update shared.args with custom model loading settings
|
||||||
# Security: only allow keys that correspond to model loading
|
# Security: only allow keys that correspond to model loading
|
||||||
# parameters exposed in the UI. Never allow security-sensitive
|
# parameters exposed in the UI. Never allow security-sensitive
|
||||||
# flags like trust_remote_code or extra_flags to be set via the API.
|
# flags like trust_remote_code or extra_flags to be set via the API.
|
||||||
blocked_keys = {'extra_flags'}
|
blocked_keys = {'extra_flags'}
|
||||||
allowed_keys = set(loaders.list_model_elements()) - blocked_keys
|
allowed_keys = set(ui.list_model_elements()) - blocked_keys
|
||||||
|
|
||||||
# Reset all loader args to their startup values before applying new ones,
|
|
||||||
# so settings from a previous API load don't leak into this one.
|
|
||||||
# Include blocked keys in the reset (safe: restores startup value, not API-controlled).
|
|
||||||
for k in allowed_keys | blocked_keys:
|
|
||||||
if hasattr(shared.args, k) and hasattr(shared.original_args, k):
|
|
||||||
setattr(shared.args, k, getattr(shared.original_args, k))
|
|
||||||
|
|
||||||
update_model_parameters(model_settings)
|
|
||||||
|
|
||||||
if args:
|
if args:
|
||||||
for k in args:
|
for k in args:
|
||||||
if k in allowed_keys and hasattr(shared.args, k):
|
if k in allowed_keys and hasattr(shared.args, k):
|
||||||
|
|
@ -70,12 +58,15 @@ def _load_model(data):
|
||||||
|
|
||||||
shared.model, shared.tokenizer = load_model(model_name)
|
shared.model, shared.tokenizer = load_model(model_name)
|
||||||
|
|
||||||
if data.get("instruction_template_str") is not None:
|
# Update shared.settings with custom generation defaults
|
||||||
shared.settings['instruction_template_str'] = data["instruction_template_str"]
|
if settings:
|
||||||
logger.info("INSTRUCTION TEMPLATE: set to custom Jinja2 string")
|
for k in settings:
|
||||||
elif data.get("instruction_template") is not None:
|
if k in shared.settings:
|
||||||
shared.settings['instruction_template_str'] = load_instruction_template(data["instruction_template"])
|
shared.settings[k] = settings[k]
|
||||||
logger.info(f"INSTRUCTION TEMPLATE: {data['instruction_template']}")
|
if k == 'truncation_length':
|
||||||
|
logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
|
||||||
|
elif k == 'instruction_template':
|
||||||
|
logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
|
||||||
|
|
||||||
|
|
||||||
def list_loras():
|
def list_loras():
|
||||||
|
|
@ -3,7 +3,7 @@ import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.linalg import norm
|
from numpy.linalg import norm
|
||||||
|
|
||||||
from .embeddings import get_embeddings
|
from extensions.openai.embeddings import get_embeddings
|
||||||
|
|
||||||
moderations_disabled = False # return 0/false
|
moderations_disabled = False # return 0/false
|
||||||
category_embeddings = None
|
category_embeddings = None
|
||||||
|
|
@ -64,4 +64,6 @@ def moderations(input):
|
||||||
'category_scores': category_scores,
|
'category_scores': category_scores,
|
||||||
}])
|
}])
|
||||||
|
|
||||||
|
print(results)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
@ -10,27 +10,24 @@ from threading import Thread
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||||
from fastapi.exceptions import RequestValidationError
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydub import AudioSegment
|
||||||
from sse_starlette import EventSourceResponse
|
from sse_starlette import EventSourceResponse
|
||||||
from starlette.concurrency import iterate_in_threadpool
|
from starlette.concurrency import iterate_in_threadpool
|
||||||
|
|
||||||
import modules.api.completions as OAIcompletions
|
import extensions.openai.completions as OAIcompletions
|
||||||
import modules.api.logits as OAIlogits
|
import extensions.openai.logits as OAIlogits
|
||||||
import modules.api.models as OAImodels
|
import extensions.openai.models as OAImodels
|
||||||
import modules.api.anthropic as Anthropic
|
from extensions.openai.tokens import token_count, token_decode, token_encode
|
||||||
from .tokens import token_count, token_decode, token_encode
|
from extensions.openai.utils import _start_cloudflared
|
||||||
from .errors import OpenAIError
|
|
||||||
from .utils import _start_cloudflared
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.models import unload_model
|
from modules.models import unload_model
|
||||||
from modules.text_generation import stop_everything_event # used by /v1/internal/stop-generation
|
from modules.text_generation import stop_everything_event # used by /v1/internal/stop-generation
|
||||||
|
|
||||||
from .typing import (
|
from .typing import (
|
||||||
AnthropicRequest,
|
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatPromptResponse,
|
ChatPromptResponse,
|
||||||
|
|
@ -55,6 +52,12 @@ from .typing import (
|
||||||
to_dict
|
to_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'embedding_device': 'cpu',
|
||||||
|
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
|
||||||
|
'debug': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _wait_for_disconnect(request: Request, stop_event: threading.Event):
|
async def _wait_for_disconnect(request: Request, stop_event: threading.Event):
|
||||||
"""Block until the client disconnects, then signal the stop_event."""
|
"""Block until the client disconnects, then signal the stop_event."""
|
||||||
|
|
@ -77,23 +80,9 @@ def verify_admin_key(authorization: str = Header(None)) -> None:
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
|
||||||
|
|
||||||
def verify_anthropic_key(x_api_key: str = Header(None, alias="x-api-key")) -> None:
|
|
||||||
expected_api_key = shared.args.api_key
|
|
||||||
if expected_api_key and (x_api_key is None or x_api_key != expected_api_key):
|
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicError(Exception):
|
|
||||||
def __init__(self, message: str, error_type: str = "invalid_request_error", status_code: int = 400):
|
|
||||||
self.message = message
|
|
||||||
self.error_type = error_type
|
|
||||||
self.status_code = status_code
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
check_key = [Depends(verify_api_key)]
|
check_key = [Depends(verify_api_key)]
|
||||||
check_admin_key = [Depends(verify_admin_key)]
|
check_admin_key = [Depends(verify_admin_key)]
|
||||||
check_anthropic_key = [Depends(verify_anthropic_key)]
|
|
||||||
|
|
||||||
# Configure CORS settings to allow all origins, methods, and headers
|
# Configure CORS settings to allow all origins, methods, and headers
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
|
@ -105,42 +94,6 @@ app.add_middleware(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.exception_handler(OpenAIError)
|
|
||||||
async def openai_error_handler(request: Request, exc: OpenAIError):
|
|
||||||
error_type = "server_error" if exc.code >= 500 else "invalid_request_error"
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=exc.code,
|
|
||||||
content={"error": {
|
|
||||||
"message": exc.message,
|
|
||||||
"type": error_type,
|
|
||||||
"param": getattr(exc, 'param', None),
|
|
||||||
"code": None
|
|
||||||
}}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.exception_handler(AnthropicError)
|
|
||||||
async def anthropic_error_handler(request: Request, exc: AnthropicError):
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=exc.status_code,
|
|
||||||
content={"type": "error", "error": {"type": exc.error_type, "message": exc.message}}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
|
||||||
async def validation_error_handler(request: Request, exc: RequestValidationError):
|
|
||||||
if request.url.path.startswith("/v1/messages"):
|
|
||||||
messages = "; ".join(
|
|
||||||
f"{'.'.join(str(l) for l in e['loc'])}: {e['msg']}" for e in exc.errors()
|
|
||||||
)
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=400,
|
|
||||||
content={"type": "error", "error": {"type": "invalid_request_error", "message": messages}}
|
|
||||||
)
|
|
||||||
|
|
||||||
return JSONResponse(status_code=422, content={"detail": exc.errors()})
|
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def validate_host_header(request: Request, call_next):
|
async def validate_host_header(request: Request, call_next):
|
||||||
# Be strict about only approving access to localhost by default
|
# Be strict about only approving access to localhost by default
|
||||||
|
|
@ -166,12 +119,6 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
|
||||||
is_legacy = "/generate" in path
|
is_legacy = "/generate" in path
|
||||||
|
|
||||||
if request_data.stream:
|
if request_data.stream:
|
||||||
if (request_data.n or 1) > 1:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=400,
|
|
||||||
content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
|
|
||||||
)
|
|
||||||
|
|
||||||
stop_event = threading.Event()
|
stop_event = threading.Event()
|
||||||
|
|
||||||
async def generator():
|
async def generator():
|
||||||
|
|
@ -183,8 +130,6 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
|
||||||
break
|
break
|
||||||
|
|
||||||
yield {"data": json.dumps(resp)}
|
yield {"data": json.dumps(resp)}
|
||||||
|
|
||||||
yield {"data": "[DONE]"}
|
|
||||||
finally:
|
finally:
|
||||||
stop_event.set()
|
stop_event.set()
|
||||||
response.close()
|
response.close()
|
||||||
|
|
@ -225,8 +170,6 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
|
||||||
break
|
break
|
||||||
|
|
||||||
yield {"data": json.dumps(resp)}
|
yield {"data": json.dumps(resp)}
|
||||||
|
|
||||||
yield {"data": "[DONE]"}
|
|
||||||
finally:
|
finally:
|
||||||
stop_event.set()
|
stop_event.set()
|
||||||
response.close()
|
response.close()
|
||||||
|
|
@ -250,76 +193,6 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@app.post('/v1/messages', dependencies=check_anthropic_key)
|
|
||||||
async def anthropic_messages(request: Request, request_data: AnthropicRequest):
|
|
||||||
body = to_dict(request_data)
|
|
||||||
model = body.get('model') or shared.model_name or 'unknown'
|
|
||||||
|
|
||||||
try:
|
|
||||||
converted = Anthropic.convert_request(body)
|
|
||||||
except Exception as e:
|
|
||||||
raise AnthropicError(message=str(e))
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await _anthropic_generate(request, request_data, converted, model)
|
|
||||||
except OpenAIError as e:
|
|
||||||
error_type = "invalid_request_error" if e.code < 500 else "api_error"
|
|
||||||
if e.code == 503:
|
|
||||||
error_type = "overloaded_error"
|
|
||||||
raise AnthropicError(message=e.message, error_type=error_type, status_code=e.code)
|
|
||||||
except Exception as e:
|
|
||||||
raise AnthropicError(message=str(e) or "Internal server error", error_type="api_error", status_code=500)
|
|
||||||
|
|
||||||
|
|
||||||
async def _anthropic_generate(request, request_data, converted, model):
|
|
||||||
if request_data.stream:
|
|
||||||
stop_event = threading.Event()
|
|
||||||
|
|
||||||
async def generator():
|
|
||||||
converter = Anthropic.StreamConverter(model)
|
|
||||||
response = OAIcompletions.stream_chat_completions(converted, is_legacy=False, stop_event=stop_event)
|
|
||||||
try:
|
|
||||||
async for resp in iterate_in_threadpool(response):
|
|
||||||
disconnected = await request.is_disconnected()
|
|
||||||
if disconnected:
|
|
||||||
break
|
|
||||||
|
|
||||||
for event in converter.process_chunk(resp):
|
|
||||||
yield event
|
|
||||||
|
|
||||||
for event in converter.finish():
|
|
||||||
yield event
|
|
||||||
except OpenAIError as e:
|
|
||||||
error_type = "invalid_request_error" if e.code < 500 else "api_error"
|
|
||||||
if e.code == 503:
|
|
||||||
error_type = "overloaded_error"
|
|
||||||
yield {
|
|
||||||
"event": "error",
|
|
||||||
"data": json.dumps({"type": "error", "error": {"type": error_type, "message": e.message}})
|
|
||||||
}
|
|
||||||
finally:
|
|
||||||
stop_event.set()
|
|
||||||
response.close()
|
|
||||||
|
|
||||||
return EventSourceResponse(generator(), sep="\n")
|
|
||||||
|
|
||||||
else:
|
|
||||||
stop_event = threading.Event()
|
|
||||||
monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
|
|
||||||
try:
|
|
||||||
openai_resp = await asyncio.to_thread(
|
|
||||||
OAIcompletions.chat_completions,
|
|
||||||
converted,
|
|
||||||
is_legacy=False,
|
|
||||||
stop_event=stop_event
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
stop_event.set()
|
|
||||||
monitor.cancel()
|
|
||||||
|
|
||||||
return JSONResponse(Anthropic.build_response(openai_resp, model))
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models", dependencies=check_key)
|
@app.get("/v1/models", dependencies=check_key)
|
||||||
@app.get("/v1/models/{model}", dependencies=check_key)
|
@app.get("/v1/models/{model}", dependencies=check_key)
|
||||||
async def handle_models(request: Request):
|
async def handle_models(request: Request):
|
||||||
|
|
@ -346,7 +219,6 @@ def handle_billing_usage():
|
||||||
@app.post('/v1/audio/transcriptions', dependencies=check_key)
|
@app.post('/v1/audio/transcriptions', dependencies=check_key)
|
||||||
async def handle_audio_transcription(request: Request):
|
async def handle_audio_transcription(request: Request):
|
||||||
import speech_recognition as sr
|
import speech_recognition as sr
|
||||||
from pydub import AudioSegment
|
|
||||||
|
|
||||||
r = sr.Recognizer()
|
r = sr.Recognizer()
|
||||||
|
|
||||||
|
|
@ -378,7 +250,7 @@ async def handle_audio_transcription(request: Request):
|
||||||
|
|
||||||
@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
|
@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
|
||||||
async def handle_image_generation(request_data: ImageGenerationRequest):
|
async def handle_image_generation(request_data: ImageGenerationRequest):
|
||||||
import modules.api.images as OAIimages
|
import extensions.openai.images as OAIimages
|
||||||
|
|
||||||
response = await asyncio.to_thread(OAIimages.generations, request_data)
|
response = await asyncio.to_thread(OAIimages.generations, request_data)
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
@ -386,7 +258,7 @@ async def handle_image_generation(request_data: ImageGenerationRequest):
|
||||||
|
|
||||||
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
|
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
|
||||||
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
|
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
|
||||||
import modules.api.embeddings as OAIembeddings
|
import extensions.openai.embeddings as OAIembeddings
|
||||||
|
|
||||||
input = request_data.input
|
input = request_data.input
|
||||||
if not input:
|
if not input:
|
||||||
|
|
@ -401,7 +273,7 @@ async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
|
||||||
|
|
||||||
@app.post("/v1/moderations", dependencies=check_key)
|
@app.post("/v1/moderations", dependencies=check_key)
|
||||||
async def handle_moderations(request: Request):
|
async def handle_moderations(request: Request):
|
||||||
import modules.api.moderations as OAImoderations
|
import extensions.openai.moderations as OAImoderations
|
||||||
|
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
input = body["input"]
|
input = body["input"]
|
||||||
|
|
@ -475,8 +347,10 @@ async def handle_list_models():
|
||||||
@app.post("/v1/internal/model/load", dependencies=check_admin_key)
|
@app.post("/v1/internal/model/load", dependencies=check_admin_key)
|
||||||
async def handle_load_model(request_data: LoadModelRequest):
|
async def handle_load_model(request_data: LoadModelRequest):
|
||||||
'''
|
'''
|
||||||
The "args" parameter can be used to modify loader flags before loading
|
This endpoint is experimental and may change in the future.
|
||||||
a model. Example:
|
|
||||||
|
The "args" parameter can be used to modify flags like "--load-in-4bit"
|
||||||
|
or "--n-gpu-layers" before loading a model. Example:
|
||||||
|
|
||||||
```
|
```
|
||||||
"args": {
|
"args": {
|
||||||
|
|
@ -485,13 +359,18 @@ async def handle_load_model(request_data: LoadModelRequest):
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Loader args are reset to their startup defaults between loads, so
|
Note that those settings will remain after loading the model. So you
|
||||||
settings from a previous load do not leak into the next one.
|
may need to change them back to load a second model.
|
||||||
|
|
||||||
The "instruction_template" parameter sets the default instruction
|
The "settings" parameter is also a dict but with keys for the
|
||||||
template by name (from user_data/instruction-templates/). The
|
shared.settings object. It can be used to modify the default instruction
|
||||||
"instruction_template_str" parameter sets it as a raw Jinja2 string
|
template like this:
|
||||||
and takes precedence over "instruction_template".
|
|
||||||
|
```
|
||||||
|
"settings": {
|
||||||
|
"instruction_template": "Alpaca"
|
||||||
|
}
|
||||||
|
```
|
||||||
'''
|
'''
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -499,17 +378,12 @@ async def handle_load_model(request_data: LoadModelRequest):
|
||||||
return JSONResponse(content="OK")
|
return JSONResponse(content="OK")
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise HTTPException(status_code=500, detail="Failed to load the model.")
|
raise HTTPException(status_code=400, detail="Failed to load the model.")
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
|
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
|
||||||
async def handle_unload_model():
|
async def handle_unload_model():
|
||||||
try:
|
|
||||||
unload_model()
|
unload_model()
|
||||||
return JSONResponse(content="OK")
|
|
||||||
except Exception:
|
|
||||||
traceback.print_exc()
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to unload the model.")
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
|
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
|
||||||
|
|
@ -537,8 +411,8 @@ async def handle_unload_loras():
|
||||||
def find_available_port(starting_port):
|
def find_available_port(starting_port):
|
||||||
"""Try the starting port, then find an available one if it's taken."""
|
"""Try the starting port, then find an available one if it's taken."""
|
||||||
try:
|
try:
|
||||||
|
# Try to create a socket with the starting port
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
||||||
s.bind(('', starting_port))
|
s.bind(('', starting_port))
|
||||||
return starting_port
|
return starting_port
|
||||||
except OSError:
|
except OSError:
|
||||||
|
|
@ -559,11 +433,8 @@ def run_server():
|
||||||
|
|
||||||
# In the server configuration:
|
# In the server configuration:
|
||||||
server_addrs = []
|
server_addrs = []
|
||||||
if shared.args.listen and shared.args.listen_host:
|
|
||||||
server_addrs.append(shared.args.listen_host)
|
|
||||||
else:
|
|
||||||
if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
|
if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
|
||||||
server_addrs.append('::' if shared.args.listen else '::1')
|
server_addrs.append('[::]' if shared.args.listen else '[::1]')
|
||||||
if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
|
if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
|
||||||
server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1')
|
server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1')
|
||||||
|
|
||||||
|
|
@ -576,15 +447,15 @@ def run_server():
|
||||||
port,
|
port,
|
||||||
shared.args.public_api_id,
|
shared.args.public_api_id,
|
||||||
max_attempts=3,
|
max_attempts=3,
|
||||||
on_start=lambda url: logger.info(f'OpenAI/Anthropic-compatible API URL:\n\n{url}/v1\n')
|
on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}\n')
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
|
url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
|
||||||
urls = [f'{url_proto}[{addr}]:{port}/v1' if ':' in addr else f'{url_proto}{addr}:{port}/v1' for addr in server_addrs]
|
urls = [f'{url_proto}{addr}:{port}' for addr in server_addrs]
|
||||||
if len(urls) > 1:
|
if len(urls) > 1:
|
||||||
logger.info('OpenAI/Anthropic-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
|
logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
|
||||||
else:
|
else:
|
||||||
logger.info('OpenAI/Anthropic-compatible API URL:\n\n' + '\n'.join(urls) + '\n')
|
logger.info('OpenAI-compatible API URL:\n\n' + '\n'.join(urls) + '\n')
|
||||||
|
|
||||||
# Log API keys
|
# Log API keys
|
||||||
if shared.args.api_key:
|
if shared.args.api_key:
|
||||||
|
|
@ -601,15 +472,7 @@ def run_server():
|
||||||
uvicorn.run(app, host=server_addrs, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile, access_log=False)
|
uvicorn.run(app, host=server_addrs, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile, access_log=False)
|
||||||
|
|
||||||
|
|
||||||
_server_started = False
|
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
global _server_started
|
|
||||||
if _server_started:
|
|
||||||
return
|
|
||||||
|
|
||||||
_server_started = True
|
|
||||||
if shared.args.nowebui:
|
if shared.args.nowebui:
|
||||||
run_server()
|
run_server()
|
||||||
else:
|
else:
|
||||||
|
|
@ -99,10 +99,6 @@ class ToolCall(BaseModel):
|
||||||
function: FunctionCall
|
function: FunctionCall
|
||||||
|
|
||||||
|
|
||||||
class StreamOptions(BaseModel):
|
|
||||||
include_usage: bool | None = False
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequestParams(BaseModel):
|
class CompletionRequestParams(BaseModel):
|
||||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
||||||
prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
|
prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
|
||||||
|
|
@ -113,11 +109,10 @@ class CompletionRequestParams(BaseModel):
|
||||||
logit_bias: dict | None = None
|
logit_bias: dict | None = None
|
||||||
logprobs: int | None = None
|
logprobs: int | None = None
|
||||||
max_tokens: int | None = 512
|
max_tokens: int | None = 512
|
||||||
n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.")
|
n: int | None = Field(default=1, description="Unused parameter.")
|
||||||
presence_penalty: float | None = shared.args.presence_penalty
|
presence_penalty: float | None = shared.args.presence_penalty
|
||||||
stop: str | List[str] | None = None
|
stop: str | List[str] | None = None
|
||||||
stream: bool | None = False
|
stream: bool | None = False
|
||||||
stream_options: StreamOptions | None = None
|
|
||||||
suffix: str | None = None
|
suffix: str | None = None
|
||||||
temperature: float | None = shared.args.temperature
|
temperature: float | None = shared.args.temperature
|
||||||
top_p: float | None = shared.args.top_p
|
top_p: float | None = shared.args.top_p
|
||||||
|
|
@ -144,33 +139,22 @@ class CompletionResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequestParams(BaseModel):
|
class ChatCompletionRequestParams(BaseModel):
|
||||||
messages: List[dict] = Field(..., min_length=1)
|
messages: List[dict]
|
||||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
||||||
frequency_penalty: float | None = shared.args.frequency_penalty
|
frequency_penalty: float | None = shared.args.frequency_penalty
|
||||||
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
|
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
|
||||||
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
|
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
|
||||||
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
|
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
|
||||||
tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
|
|
||||||
logit_bias: dict | None = None
|
logit_bias: dict | None = None
|
||||||
logprobs: bool | None = None
|
|
||||||
top_logprobs: int | None = None
|
|
||||||
max_tokens: int | None = None
|
max_tokens: int | None = None
|
||||||
max_completion_tokens: int | None = None
|
|
||||||
n: int | None = Field(default=1, description="Unused parameter.")
|
n: int | None = Field(default=1, description="Unused parameter.")
|
||||||
presence_penalty: float | None = shared.args.presence_penalty
|
presence_penalty: float | None = shared.args.presence_penalty
|
||||||
stop: str | List[str] | None = None
|
stop: str | List[str] | None = None
|
||||||
stream: bool | None = False
|
stream: bool | None = False
|
||||||
stream_options: StreamOptions | None = None
|
|
||||||
temperature: float | None = shared.args.temperature
|
temperature: float | None = shared.args.temperature
|
||||||
top_p: float | None = shared.args.top_p
|
top_p: float | None = shared.args.top_p
|
||||||
user: str | None = Field(default=None, description="Unused parameter.")
|
user: str | None = Field(default=None, description="Unused parameter.")
|
||||||
|
|
||||||
@model_validator(mode='after')
|
|
||||||
def resolve_max_tokens(self):
|
|
||||||
if self.max_tokens is None and self.max_completion_tokens is not None:
|
|
||||||
self.max_tokens = self.max_completion_tokens
|
|
||||||
return self
|
|
||||||
|
|
||||||
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
|
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
|
||||||
|
|
||||||
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
|
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
|
||||||
|
|
@ -271,8 +255,7 @@ class ModelListResponse(BaseModel):
|
||||||
class LoadModelRequest(BaseModel):
|
class LoadModelRequest(BaseModel):
|
||||||
model_name: str
|
model_name: str
|
||||||
args: dict | None = None
|
args: dict | None = None
|
||||||
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. Sets the default template for all subsequent API requests.")
|
settings: dict | None = None
|
||||||
instruction_template_str: str | None = Field(default=None, description="A Jinja2 instruction template string. If set, takes precedence over instruction_template.")
|
|
||||||
|
|
||||||
|
|
||||||
class LoraListResponse(BaseModel):
|
class LoraListResponse(BaseModel):
|
||||||
|
|
@ -283,25 +266,6 @@ class LoadLorasRequest(BaseModel):
|
||||||
lora_names: List[str]
|
lora_names: List[str]
|
||||||
|
|
||||||
|
|
||||||
class AnthropicRequestParams(BaseModel):
|
|
||||||
model: str | None = None
|
|
||||||
messages: List[dict] = Field(..., min_length=1)
|
|
||||||
max_tokens: int
|
|
||||||
system: str | list | None = None
|
|
||||||
temperature: float | None = shared.args.temperature
|
|
||||||
top_p: float | None = shared.args.top_p
|
|
||||||
stop_sequences: list[str] | None = None
|
|
||||||
stream: bool = False
|
|
||||||
tools: list[dict] | None = None
|
|
||||||
tool_choice: dict | None = None
|
|
||||||
thinking: dict | None = None
|
|
||||||
metadata: dict | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicRequest(GenerationOptions, AnthropicRequestParams):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ImageGenerationRequest(BaseModel):
|
class ImageGenerationRequest(BaseModel):
|
||||||
"""Image-specific parameters for generation."""
|
"""Image-specific parameters for generation."""
|
||||||
prompt: str
|
prompt: str
|
||||||
527
extensions/openai/utils.py
Normal file
527
extensions/openai/utils.py
Normal file
|
|
@ -0,0 +1,527 @@
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def float_list_to_base64(float_array: np.ndarray) -> str:
|
||||||
|
# Convert the list to a float32 array that the OpenAPI client expects
|
||||||
|
# float_array = np.array(float_list, dtype="float32")
|
||||||
|
|
||||||
|
# Get raw bytes
|
||||||
|
bytes_array = float_array.tobytes()
|
||||||
|
|
||||||
|
# Encode bytes into base64
|
||||||
|
encoded_bytes = base64.b64encode(bytes_array)
|
||||||
|
|
||||||
|
# Turn raw base64 encoded bytes into ASCII
|
||||||
|
ascii_string = encoded_bytes.decode('ascii')
|
||||||
|
return ascii_string
|
||||||
|
|
||||||
|
|
||||||
|
def debug_msg(*args, **kwargs):
|
||||||
|
from extensions.openai.script import params
|
||||||
|
if os.environ.get("OPENEDAI_DEBUG", params.get('debug', 0)):
|
||||||
|
print(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
||||||
|
try:
|
||||||
|
from flask_cloudflared import _run_cloudflared
|
||||||
|
except ImportError:
|
||||||
|
print('You should install flask_cloudflared manually')
|
||||||
|
raise Exception(
|
||||||
|
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
|
||||||
|
|
||||||
|
for _ in range(max_attempts):
|
||||||
|
try:
|
||||||
|
if tunnel_id is not None:
|
||||||
|
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
|
||||||
|
else:
|
||||||
|
public_url = _run_cloudflared(port, port + 1)
|
||||||
|
|
||||||
|
if on_start:
|
||||||
|
on_start(public_url)
|
||||||
|
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
time.sleep(3)
|
||||||
|
|
||||||
|
raise Exception('Could not start cloudflared.')
|
||||||
|
|
||||||
|
|
||||||
|
def getToolCallId() -> str:
|
||||||
|
letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
b = [random.choice(letter_bytes) for _ in range(8)]
|
||||||
|
return "call_" + "".join(b).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
|
||||||
|
# check if property 'function' exists and is a dictionary, otherwise adapt dict
|
||||||
|
if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
|
||||||
|
candidate_dict = {"type": "function", "function": candidate_dict}
|
||||||
|
if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
|
||||||
|
candidate_dict['name'] = candidate_dict['function']
|
||||||
|
del candidate_dict['function']
|
||||||
|
candidate_dict = {"type": "function", "function": candidate_dict}
|
||||||
|
if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
|
||||||
|
# check if 'name' exists within 'function' and is part of known tools
|
||||||
|
if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
|
||||||
|
candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value
|
||||||
|
# map property 'parameters' used by some older models to 'arguments'
|
||||||
|
if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
|
||||||
|
candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
|
||||||
|
del candidate_dict["function"]["parameters"]
|
||||||
|
return candidate_dict
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extractBalancedJson(text: str, start: int) -> str | None:
|
||||||
|
"""Extract a balanced JSON object from text starting at the given position.
|
||||||
|
|
||||||
|
Walks through the string tracking brace depth and string boundaries
|
||||||
|
to correctly handle arbitrary nesting levels.
|
||||||
|
"""
|
||||||
|
if start >= len(text) or text[start] != '{':
|
||||||
|
return None
|
||||||
|
depth = 0
|
||||||
|
in_string = False
|
||||||
|
escape_next = False
|
||||||
|
for i in range(start, len(text)):
|
||||||
|
c = text[i]
|
||||||
|
if escape_next:
|
||||||
|
escape_next = False
|
||||||
|
continue
|
||||||
|
if c == '\\' and in_string:
|
||||||
|
escape_next = True
|
||||||
|
continue
|
||||||
|
if c == '"':
|
||||||
|
in_string = not in_string
|
||||||
|
continue
|
||||||
|
if in_string:
|
||||||
|
continue
|
||||||
|
if c == '{':
|
||||||
|
depth += 1
|
||||||
|
elif c == '}':
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
return text[start:i + 1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parseChannelToolCalls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse channel-based tool calls used by GPT-OSS and similar models.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
for m in re.finditer(
|
||||||
|
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>',
|
||||||
|
answer
|
||||||
|
):
|
||||||
|
func_name = m.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
json_str = _extractBalancedJson(answer, m.end())
|
||||||
|
if json_str is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
arguments = json.loads(json_str)
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
def _parseBareNameToolCalls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse bare function-name style tool calls used by Mistral and similar models.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
functionName{"arg": "value"}
|
||||||
|
Multiple calls are concatenated directly or separated by whitespace.
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
# Match tool name followed by opening brace, then extract balanced JSON
|
||||||
|
escaped_names = [re.escape(name) for name in tool_names]
|
||||||
|
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
|
||||||
|
for match in re.finditer(pattern, answer):
|
||||||
|
text = match.group(0)
|
||||||
|
name = None
|
||||||
|
for n in tool_names:
|
||||||
|
if text.startswith(n):
|
||||||
|
name = n
|
||||||
|
break
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
brace_start = match.end() - 1
|
||||||
|
json_str = _extractBalancedJson(answer, brace_start)
|
||||||
|
if json_str is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
arguments = json.loads(json_str)
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
def _parseXmlParamToolCalls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse XML-parameter style tool calls used by Qwen3.5 and similar models.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<tool_call>
|
||||||
|
<function=function_name>
|
||||||
|
<parameter=param_name>value</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
|
||||||
|
tc_content = tc_match.group(1)
|
||||||
|
func_match = re.search(r'<function=([^>]+)>', tc_content)
|
||||||
|
if not func_match:
|
||||||
|
continue
|
||||||
|
func_name = func_match.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
arguments = {}
|
||||||
|
for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
|
||||||
|
param_name = param_match.group(1).strip()
|
||||||
|
param_value = param_match.group(2).strip()
|
||||||
|
try:
|
||||||
|
param_value = json.loads(param_value)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass # keep as string
|
||||||
|
arguments[param_name] = param_value
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
def _parseKimiToolCalls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse Kimi-K2-style tool calls using pipe-delimited tokens.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<|tool_calls_section_begin|>
|
||||||
|
<|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
|
||||||
|
<|tool_calls_section_end|>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
for m in re.finditer(
|
||||||
|
r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
|
||||||
|
answer
|
||||||
|
):
|
||||||
|
func_name = m.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
json_str = _extractBalancedJson(answer, m.end())
|
||||||
|
if json_str is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
arguments = json.loads(json_str)
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
def _parseMiniMaxToolCalls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse MiniMax-style tool calls using invoke/parameter XML tags.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<minimax:tool_call>
|
||||||
|
<invoke name="function_name">
|
||||||
|
<parameter name="param_name">value</parameter>
|
||||||
|
</invoke>
|
||||||
|
</minimax:tool_call>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
|
||||||
|
tc_content = tc_match.group(1)
|
||||||
|
# Split on <invoke> to handle multiple parallel calls in one block
|
||||||
|
for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
|
||||||
|
func_name = invoke_match.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
invoke_body = invoke_match.group(2)
|
||||||
|
arguments = {}
|
||||||
|
for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
|
||||||
|
param_name = param_match.group(1).strip()
|
||||||
|
param_value = param_match.group(2).strip()
|
||||||
|
try:
|
||||||
|
param_value = json.loads(param_value)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass # keep as string
|
||||||
|
arguments[param_name] = param_value
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
for m in re.finditer(
|
||||||
|
r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*',
|
||||||
|
answer
|
||||||
|
):
|
||||||
|
func_name = m.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
json_str = _extractBalancedJson(answer, m.end())
|
||||||
|
if json_str is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
arguments = json.loads(json_str)
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
def _parseGlmToolCalls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse GLM-style tool calls using arg_key/arg_value XML pairs.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<tool_call>function_name
|
||||||
|
<arg_key>key1</arg_key>
|
||||||
|
<arg_value>value1</arg_value>
|
||||||
|
</tool_call>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
|
||||||
|
tc_content = tc_match.group(1)
|
||||||
|
# First non-tag text is the function name
|
||||||
|
name_match = re.match(r'([^<\s]+)', tc_content.strip())
|
||||||
|
if not name_match:
|
||||||
|
continue
|
||||||
|
func_name = name_match.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
# Extract arg_key/arg_value pairs
|
||||||
|
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
|
||||||
|
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
|
||||||
|
if len(keys) != len(vals):
|
||||||
|
continue
|
||||||
|
arguments = {}
|
||||||
|
for k, v in zip(keys, vals):
|
||||||
|
try:
|
||||||
|
v = json.loads(v)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass # keep as string
|
||||||
|
arguments[k] = v
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
def _parsePythonicToolCalls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
[func_name(param1="value1", param2="value2"), func_name2(...)]
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
# Match a bracketed list of function calls
|
||||||
|
bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
|
||||||
|
if not bracket_match:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
inner = bracket_match.group(1)
|
||||||
|
|
||||||
|
# Build pattern for known tool names
|
||||||
|
escaped_names = [re.escape(name) for name in tool_names]
|
||||||
|
name_pattern = '|'.join(escaped_names)
|
||||||
|
|
||||||
|
for call_match in re.finditer(
|
||||||
|
r'(' + name_pattern + r')\(([^)]*)\)',
|
||||||
|
inner
|
||||||
|
):
|
||||||
|
func_name = call_match.group(1)
|
||||||
|
params_str = call_match.group(2).strip()
|
||||||
|
arguments = {}
|
||||||
|
|
||||||
|
if params_str:
|
||||||
|
# Parse key="value" pairs, handling commas inside quoted values
|
||||||
|
for param_match in re.finditer(
|
||||||
|
r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
|
||||||
|
params_str
|
||||||
|
):
|
||||||
|
param_name = param_match.group(1)
|
||||||
|
param_value = param_match.group(2).strip()
|
||||||
|
# Strip surrounding quotes
|
||||||
|
if (param_value.startswith('"') and param_value.endswith('"')) or \
|
||||||
|
(param_value.startswith("'") and param_value.endswith("'")):
|
||||||
|
param_value = param_value[1:-1]
|
||||||
|
# Try to parse as JSON for numeric/bool/null values
|
||||||
|
try:
|
||||||
|
param_value = json.loads(param_value)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass
|
||||||
|
arguments[param_name] = param_value
|
||||||
|
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
def parseToolCall(answer: str, tool_names: list[str]):
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
# abort on very short answers to save computation cycles
|
||||||
|
if len(answer) < 10:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
|
||||||
|
matches = _parseDeepSeekToolCalls(answer, tool_names)
|
||||||
|
if matches:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Check for Kimi-K2-style tool calls (pipe-delimited tokens)
|
||||||
|
matches = _parseKimiToolCalls(answer, tool_names)
|
||||||
|
if matches:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Check for channel-based tool calls (e.g. GPT-OSS format)
|
||||||
|
matches = _parseChannelToolCalls(answer, tool_names)
|
||||||
|
if matches:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Check for MiniMax-style tool calls (invoke/parameter XML tags)
|
||||||
|
matches = _parseMiniMaxToolCalls(answer, tool_names)
|
||||||
|
if matches:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Check for GLM-style tool calls (arg_key/arg_value XML pairs)
|
||||||
|
matches = _parseGlmToolCalls(answer, tool_names)
|
||||||
|
if matches:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
|
||||||
|
matches = _parseXmlParamToolCalls(answer, tool_names)
|
||||||
|
if matches:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Check for bare function-name style tool calls (e.g. Mistral format)
|
||||||
|
matches = _parseBareNameToolCalls(answer, tool_names)
|
||||||
|
if matches:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Check for pythonic-style tool calls (e.g. Llama 4 format)
|
||||||
|
matches = _parsePythonicToolCalls(answer, tool_names)
|
||||||
|
if matches:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
|
||||||
|
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
|
||||||
|
|
||||||
|
for pattern in patterns:
|
||||||
|
for match in re.finditer(pattern, answer, re.DOTALL):
|
||||||
|
# print(match.group(2))
|
||||||
|
if match.group(2) is None:
|
||||||
|
continue
|
||||||
|
# remove backtick wraps if present
|
||||||
|
candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
|
||||||
|
candidate = re.sub(r"```$", "", candidate.strip())
|
||||||
|
# unwrap inner tags
|
||||||
|
candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
|
||||||
|
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
|
||||||
|
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
|
||||||
|
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
|
||||||
|
if not candidate.strip().startswith("["):
|
||||||
|
candidate = "[" + candidate + "]"
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
try:
|
||||||
|
# parse the candidate JSON into a dictionary
|
||||||
|
candidates = json.loads(candidate)
|
||||||
|
if not isinstance(candidates, list):
|
||||||
|
candidates = [candidates]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Ignore invalid JSON silently
|
||||||
|
continue
|
||||||
|
|
||||||
|
for candidate_dict in candidates:
|
||||||
|
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
|
||||||
|
if checked_candidate is not None:
|
||||||
|
matches.append(checked_candidate)
|
||||||
|
|
||||||
|
# last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
|
||||||
|
if len(matches) == 0:
|
||||||
|
try:
|
||||||
|
candidate = answer
|
||||||
|
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
|
||||||
|
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
|
||||||
|
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
|
||||||
|
if not candidate.strip().startswith("["):
|
||||||
|
candidate = "[" + candidate + "]"
|
||||||
|
# parse the candidate JSON into a dictionary
|
||||||
|
candidates = json.loads(candidate)
|
||||||
|
if not isinstance(candidates, list):
|
||||||
|
candidates = [candidates]
|
||||||
|
for candidate_dict in candidates:
|
||||||
|
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
|
||||||
|
if checked_candidate is not None:
|
||||||
|
matches.append(checked_candidate)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Ignore invalid JSON silently
|
||||||
|
pass
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
@ -2,11 +2,8 @@ import concurrent.futures
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from modules.web_search import _validate_url
|
|
||||||
|
|
||||||
|
|
||||||
def download_single(url):
|
def download_single(url):
|
||||||
_validate_url(url)
|
|
||||||
headers = {
|
headers = {
|
||||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -107,7 +107,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||||
|
|
||||||
elif path in ['/api/v1/delete', '/api/delete']:
|
elif path in ['/api/v1/delete', '/api/delete']:
|
||||||
metadata = body.get('metadata')
|
metadata = body.get('metadata')
|
||||||
if metadata is None:
|
if corpus is None:
|
||||||
self._send_412_error("Missing parameter 'metadata'")
|
self._send_412_error("Missing parameter 'metadata'")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,14 +5,12 @@ import requests
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
import extensions.superboogav2.parameters as parameters
|
import extensions.superboogav2.parameters as parameters
|
||||||
from modules.web_search import _validate_url
|
|
||||||
|
|
||||||
from .data_processor import process_and_add_to_collector
|
from .data_processor import process_and_add_to_collector
|
||||||
from .utils import create_metadata_source
|
from .utils import create_metadata_source
|
||||||
|
|
||||||
|
|
||||||
def _download_single(url):
|
def _download_single(url):
|
||||||
_validate_url(url)
|
|
||||||
response = requests.get(url, timeout=5)
|
response = requests.get(url, timeout=5)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.content
|
return response.content
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,8 @@ Allows you to enter your inputs in chat mode using your microphone.
|
||||||
To adjust your default settings, you can add the following to your settings.yaml file.
|
To adjust your default settings, you can add the following to your settings.yaml file.
|
||||||
|
|
||||||
```
|
```
|
||||||
whisper_stt-whisper_language: chinese
|
whisper_stt-whipser_language: chinese
|
||||||
whisper_stt-whisper_model: tiny
|
whisper_stt-whipser_model: tiny
|
||||||
whisper_stt-auto_submit: False
|
whisper_stt-auto_submit: False
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,13 +18,13 @@ input_hijack = {
|
||||||
|
|
||||||
# parameters which can be customized in settings.yaml of webui
|
# parameters which can be customized in settings.yaml of webui
|
||||||
params = {
|
params = {
|
||||||
'whisper_language': 'english',
|
'whipser_language': 'english',
|
||||||
'whisper_model': 'small.en',
|
'whipser_model': 'small.en',
|
||||||
'auto_submit': True
|
'auto_submit': True
|
||||||
}
|
}
|
||||||
|
|
||||||
startup_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
startup_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
WHISPERMODEL = whisper.load_model(params['whisper_model'], device=startup_device)
|
WHISPERMODEL = whisper.load_model(params['whipser_model'], device=startup_device)
|
||||||
|
|
||||||
|
|
||||||
def chat_input_modifier(text, visible_text, state):
|
def chat_input_modifier(text, visible_text, state):
|
||||||
|
|
@ -36,7 +36,7 @@ def chat_input_modifier(text, visible_text, state):
|
||||||
return text, visible_text
|
return text, visible_text
|
||||||
|
|
||||||
|
|
||||||
def do_stt(audio, whisper_language):
|
def do_stt(audio, whipser_language):
|
||||||
# use pydub to convert sample_rate and sample_width for whisper input
|
# use pydub to convert sample_rate and sample_width for whisper input
|
||||||
dubaudio = AudioSegment.from_file(io.BytesIO(audio))
|
dubaudio = AudioSegment.from_file(io.BytesIO(audio))
|
||||||
dubaudio = dubaudio.set_channels(1)
|
dubaudio = dubaudio.set_channels(1)
|
||||||
|
|
@ -46,20 +46,20 @@ def do_stt(audio, whisper_language):
|
||||||
# same method to get the array as openai whisper repo used from wav file
|
# same method to get the array as openai whisper repo used from wav file
|
||||||
audio_np = np.frombuffer(dubaudio.raw_data, np.int16).flatten().astype(np.float32) / 32768.0
|
audio_np = np.frombuffer(dubaudio.raw_data, np.int16).flatten().astype(np.float32) / 32768.0
|
||||||
|
|
||||||
if len(whisper_language) == 0:
|
if len(whipser_language) == 0:
|
||||||
result = WHISPERMODEL.transcribe(audio=audio_np)
|
result = WHISPERMODEL.transcribe(audio=audio_np)
|
||||||
else:
|
else:
|
||||||
result = WHISPERMODEL.transcribe(audio=audio_np, language=whisper_language)
|
result = WHISPERMODEL.transcribe(audio=audio_np, language=whipser_language)
|
||||||
return result["text"]
|
return result["text"]
|
||||||
|
|
||||||
|
|
||||||
def auto_transcribe(audio, auto_submit, whisper_language):
|
def auto_transcribe(audio, auto_submit, whipser_language):
|
||||||
if audio is None or audio == "":
|
if audio is None or audio == "":
|
||||||
print("Whisper received no audio data")
|
print("Whisper received no audio data")
|
||||||
return "", ""
|
return "", ""
|
||||||
audio_bytes = base64.b64decode(audio.split(',')[1])
|
audio_bytes = base64.b64decode(audio.split(',')[1])
|
||||||
|
|
||||||
transcription = do_stt(audio_bytes, whisper_language)
|
transcription = do_stt(audio_bytes, whipser_language)
|
||||||
if auto_submit:
|
if auto_submit:
|
||||||
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
||||||
return transcription
|
return transcription
|
||||||
|
|
@ -78,7 +78,7 @@ def reload_whispermodel(whisper_model_name: str, whisper_language: str, device:
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
WHISPERMODEL = whisper.load_model(whisper_model_name, device=device)
|
WHISPERMODEL = whisper.load_model(whisper_model_name, device=device)
|
||||||
params.update({"whisper_model": whisper_model_name})
|
params.update({"whipser_model": whisper_model_name})
|
||||||
if ".en" in whisper_model_name:
|
if ".en" in whisper_model_name:
|
||||||
whisper_language = "english"
|
whisper_language = "english"
|
||||||
audio_update = gr.Audio.update(interactive=True)
|
audio_update = gr.Audio.update(interactive=True)
|
||||||
|
|
@ -96,8 +96,8 @@ def ui():
|
||||||
with gr.Accordion("Settings", open=False):
|
with gr.Accordion("Settings", open=False):
|
||||||
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit'])
|
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit'])
|
||||||
device_dropd = gr.Dropdown(label='Device', value=str(startup_device), choices=["cuda", "cpu", "none"])
|
device_dropd = gr.Dropdown(label='Device', value=str(startup_device), choices=["cuda", "cpu", "none"])
|
||||||
whisper_model_dropd = gr.Dropdown(label='Whisper Model', value=params['whisper_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "turbo"])
|
whisper_model_dropd = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "turbo"])
|
||||||
whisper_language = gr.Dropdown(label='Whisper Language', value=params['whisper_language'], choices=["english", "chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
|
whisper_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["english", "chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
|
||||||
|
|
||||||
audio.change(
|
audio.change(
|
||||||
auto_transcribe, [audio, auto_submit, whisper_language], [shared.gradio['textbox']]).then(
|
auto_transcribe, [audio, auto_submit, whisper_language], [shared.gradio['textbox']]).then(
|
||||||
|
|
@ -105,7 +105,7 @@ def ui():
|
||||||
|
|
||||||
device_dropd.input(reload_whispermodel, [whisper_model_dropd, whisper_language, device_dropd], [whisper_model_dropd, whisper_language, device_dropd, audio])
|
device_dropd.input(reload_whispermodel, [whisper_model_dropd, whisper_language, device_dropd], [whisper_model_dropd, whisper_language, device_dropd, audio])
|
||||||
whisper_model_dropd.change(reload_whispermodel, [whisper_model_dropd, whisper_language, device_dropd], [whisper_model_dropd, whisper_language, device_dropd, audio])
|
whisper_model_dropd.change(reload_whispermodel, [whisper_model_dropd, whisper_language, device_dropd], [whisper_model_dropd, whisper_language, device_dropd, audio])
|
||||||
whisper_language.change(lambda x: params.update({"whisper_language": x}), whisper_language, None)
|
whisper_language.change(lambda x: params.update({"whipser_language": x}), whisper_language, None)
|
||||||
auto_submit.change(lambda x: params.update({"auto_submit": x}), auto_submit, None)
|
auto_submit.change(lambda x: params.update({"auto_submit": x}), auto_submit, None)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
function toggleDarkMode() {
|
function toggleDarkMode() {
|
||||||
document.body.classList.toggle("dark");
|
document.body.classList.toggle("dark");
|
||||||
const currentCSS = document.getElementById("highlight-css");
|
var currentCSS = document.getElementById("highlight-css");
|
||||||
if (currentCSS.getAttribute("href") === "file/css/highlightjs/github-dark.min.css") {
|
if (currentCSS.getAttribute("href") === "file/css/highlightjs/github-dark.min.css") {
|
||||||
currentCSS.setAttribute("href", "file/css/highlightjs/github.min.css");
|
currentCSS.setAttribute("href", "file/css/highlightjs/github.min.css");
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -9,10 +9,12 @@ function toggleDarkMode() {
|
||||||
|
|
||||||
// Re-highlight all code blocks once stylesheet loads
|
// Re-highlight all code blocks once stylesheet loads
|
||||||
currentCSS.onload = function() {
|
currentCSS.onload = function() {
|
||||||
// Clear data-highlighted so hljs will re-process with the new theme
|
const messageBodies = document.getElementById("chat").querySelectorAll(".message-body");
|
||||||
document.querySelectorAll("#chat .message-body pre code[data-highlighted]").forEach((codeBlock) => {
|
messageBodies.forEach((messageBody) => {
|
||||||
delete codeBlock.dataset.highlighted;
|
const codeBlocks = messageBody.querySelectorAll("pre code");
|
||||||
|
codeBlocks.forEach((codeBlock) => {
|
||||||
|
hljs.highlightElement(codeBlock);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
doSyntaxHighlighting();
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,35 +1,11 @@
|
||||||
// -------------------------------------------------
|
|
||||||
// Shared helpers
|
|
||||||
// -------------------------------------------------
|
|
||||||
|
|
||||||
function getProfilePictureUrl() {
|
|
||||||
return "/file/user_data/cache/pfp_character.png?time=" + Date.now();
|
|
||||||
}
|
|
||||||
|
|
||||||
const MESSAGE_SELECTOR = ".message, .user-message, .assistant-message";
|
|
||||||
|
|
||||||
function getMessageElement(element) {
|
|
||||||
if (!element) return null;
|
|
||||||
return element.closest(MESSAGE_SELECTOR);
|
|
||||||
}
|
|
||||||
|
|
||||||
function isUserRole(messageElement) {
|
|
||||||
return messageElement.classList.contains("user-message") ||
|
|
||||||
messageElement.querySelector(".text-you") !== null ||
|
|
||||||
messageElement.querySelector(".circle-you") !== null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trigger a synthetic 'input' event so Gradio picks up programmatic value changes
|
|
||||||
function dispatchGradioInput(element) {
|
|
||||||
element.dispatchEvent(new Event("input", { bubbles: true }));
|
|
||||||
}
|
|
||||||
|
|
||||||
// -------------------------------------------------
|
// -------------------------------------------------
|
||||||
// Event handlers
|
// Event handlers
|
||||||
// -------------------------------------------------
|
// -------------------------------------------------
|
||||||
|
|
||||||
function copyToClipboard(element) {
|
function copyToClipboard(element) {
|
||||||
const messageElement = getMessageElement(element);
|
if (!element) return;
|
||||||
|
|
||||||
|
const messageElement = element.closest(".message, .user-message, .assistant-message");
|
||||||
if (!messageElement) return;
|
if (!messageElement) return;
|
||||||
|
|
||||||
const rawText = messageElement.getAttribute("data-raw");
|
const rawText = messageElement.getAttribute("data-raw");
|
||||||
|
|
@ -72,7 +48,9 @@ function fallbackCopyToClipboard(text) {
|
||||||
}
|
}
|
||||||
|
|
||||||
function branchHere(element) {
|
function branchHere(element) {
|
||||||
const messageElement = getMessageElement(element);
|
if (!element) return;
|
||||||
|
|
||||||
|
const messageElement = element.closest(".message, .user-message, .assistant-message");
|
||||||
if (!messageElement) return;
|
if (!messageElement) return;
|
||||||
|
|
||||||
const index = messageElement.getAttribute("data-index");
|
const index = messageElement.getAttribute("data-index");
|
||||||
|
|
@ -91,7 +69,11 @@ function branchHere(element) {
|
||||||
}
|
}
|
||||||
|
|
||||||
branchIndexInput.value = index;
|
branchIndexInput.value = index;
|
||||||
dispatchGradioInput(branchIndexInput);
|
|
||||||
|
// Trigger any 'change' or 'input' events Gradio might be listening for
|
||||||
|
const event = new Event("input", { bubbles: true });
|
||||||
|
branchIndexInput.dispatchEvent(event);
|
||||||
|
|
||||||
branchButton.click();
|
branchButton.click();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -100,7 +82,9 @@ function branchHere(element) {
|
||||||
// -------------------------------------------------
|
// -------------------------------------------------
|
||||||
|
|
||||||
function editHere(buttonElement) {
|
function editHere(buttonElement) {
|
||||||
const messageElement = getMessageElement(buttonElement);
|
if (!buttonElement) return;
|
||||||
|
|
||||||
|
const messageElement = buttonElement.closest(".message, .user-message, .assistant-message");
|
||||||
if (!messageElement) return;
|
if (!messageElement) return;
|
||||||
|
|
||||||
const messageBody = messageElement.querySelector(".message-body");
|
const messageBody = messageElement.querySelector(".message-body");
|
||||||
|
|
@ -113,7 +97,12 @@ function editHere(buttonElement) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
startEditing(messageElement, messageBody, isUserRole(messageElement));
|
// Determine role based on message element - handle different chat modes
|
||||||
|
const isUserMessage = messageElement.classList.contains("user-message") ||
|
||||||
|
messageElement.querySelector(".text-you") !== null ||
|
||||||
|
messageElement.querySelector(".circle-you") !== null;
|
||||||
|
|
||||||
|
startEditing(messageElement, messageBody, isUserMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
function startEditing(messageElement, messageBody, isUserMessage) {
|
function startEditing(messageElement, messageBody, isUserMessage) {
|
||||||
|
|
@ -220,22 +209,30 @@ function submitMessageEdit(index, newText, isUserMessage) {
|
||||||
editTextInput.value = newText;
|
editTextInput.value = newText;
|
||||||
editRoleInput.value = isUserMessage ? "user" : "assistant";
|
editRoleInput.value = isUserMessage ? "user" : "assistant";
|
||||||
|
|
||||||
dispatchGradioInput(editIndexInput);
|
editIndexInput.dispatchEvent(new Event("input", { bubbles: true }));
|
||||||
dispatchGradioInput(editTextInput);
|
editTextInput.dispatchEvent(new Event("input", { bubbles: true }));
|
||||||
dispatchGradioInput(editRoleInput);
|
editRoleInput.dispatchEvent(new Event("input", { bubbles: true }));
|
||||||
|
|
||||||
editButton.click();
|
editButton.click();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
function navigateVersion(element, direction) {
|
function navigateVersion(element, direction) {
|
||||||
const messageElement = getMessageElement(element);
|
if (!element) return;
|
||||||
|
|
||||||
|
const messageElement = element.closest(".message, .user-message, .assistant-message");
|
||||||
if (!messageElement) return;
|
if (!messageElement) return;
|
||||||
|
|
||||||
const index = messageElement.getAttribute("data-index");
|
const index = messageElement.getAttribute("data-index");
|
||||||
if (!index) return;
|
if (!index) return;
|
||||||
|
|
||||||
const role = isUserRole(messageElement) ? "user" : "assistant";
|
// Determine role based on message element classes
|
||||||
|
let role = "assistant"; // Default role
|
||||||
|
if (messageElement.classList.contains("user-message") ||
|
||||||
|
messageElement.querySelector(".text-you") ||
|
||||||
|
messageElement.querySelector(".circle-you")) {
|
||||||
|
role = "user";
|
||||||
|
}
|
||||||
|
|
||||||
const indexInput = document.getElementById("Navigate-message-index")?.querySelector("input");
|
const indexInput = document.getElementById("Navigate-message-index")?.querySelector("input");
|
||||||
const directionInput = document.getElementById("Navigate-direction")?.querySelector("textarea");
|
const directionInput = document.getElementById("Navigate-direction")?.querySelector("textarea");
|
||||||
|
|
@ -251,9 +248,11 @@ function navigateVersion(element, direction) {
|
||||||
directionInput.value = direction;
|
directionInput.value = direction;
|
||||||
roleInput.value = role;
|
roleInput.value = role;
|
||||||
|
|
||||||
dispatchGradioInput(indexInput);
|
// Trigger 'input' events for Gradio to pick up changes
|
||||||
dispatchGradioInput(directionInput);
|
const event = new Event("input", { bubbles: true });
|
||||||
dispatchGradioInput(roleInput);
|
indexInput.dispatchEvent(event);
|
||||||
|
directionInput.dispatchEvent(event);
|
||||||
|
roleInput.dispatchEvent(event);
|
||||||
|
|
||||||
navigateButton.click();
|
navigateButton.click();
|
||||||
}
|
}
|
||||||
|
|
@ -270,51 +269,9 @@ function removeLastClick() {
|
||||||
document.getElementById("Remove-last").click();
|
document.getElementById("Remove-last").click();
|
||||||
}
|
}
|
||||||
|
|
||||||
function autoScrollToBottom() {
|
|
||||||
if (!window.isScrolled) {
|
|
||||||
const chatParent = document.getElementById("chat")?.parentNode?.parentNode?.parentNode;
|
|
||||||
if (chatParent) {
|
|
||||||
const maxScroll = chatParent.scrollHeight - chatParent.clientHeight;
|
|
||||||
if (maxScroll > 0 && chatParent.scrollTop < maxScroll - 1) {
|
|
||||||
chatParent.scrollTop = maxScroll;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function updateInstructPadding() {
|
|
||||||
const chatElement = document.getElementById("chat");
|
|
||||||
if (chatElement && chatElement.getAttribute("data-mode") === "instruct") {
|
|
||||||
const messagesContainer = chatElement.querySelector(".messages");
|
|
||||||
const lastChild = messagesContainer?.lastElementChild;
|
|
||||||
const prevSibling = lastChild?.previousElementSibling;
|
|
||||||
if (lastChild && prevSibling && chatElement.offsetHeight > 0) {
|
|
||||||
let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight);
|
|
||||||
if (window.innerWidth <= 924) {
|
|
||||||
bufferHeight = Math.max(0, bufferHeight - 32);
|
|
||||||
}
|
|
||||||
messagesContainer.style.paddingBottom = `${bufferHeight}px`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let pendingMorphdomData = null;
|
|
||||||
let morphdomRafId = null;
|
|
||||||
|
|
||||||
function handleMorphdomUpdate(data) {
|
function handleMorphdomUpdate(data) {
|
||||||
pendingMorphdomData = data;
|
|
||||||
if (!morphdomRafId) {
|
|
||||||
morphdomRafId = requestAnimationFrame(() => {
|
|
||||||
morphdomRafId = null;
|
|
||||||
applyMorphdomUpdate(pendingMorphdomData);
|
|
||||||
pendingMorphdomData = null;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function applyMorphdomUpdate(data) {
|
|
||||||
// Determine target element and use it as query scope
|
// Determine target element and use it as query scope
|
||||||
let target_element, target_html;
|
var target_element, target_html;
|
||||||
if (data.last_message_only) {
|
if (data.last_message_only) {
|
||||||
const childNodes = document.getElementsByClassName("messages")[0].childNodes;
|
const childNodes = document.getElementsByClassName("messages")[0].childNodes;
|
||||||
target_element = childNodes[childNodes.length - 1];
|
target_element = childNodes[childNodes.length - 1];
|
||||||
|
|
@ -326,22 +283,28 @@ function applyMorphdomUpdate(data) {
|
||||||
|
|
||||||
const queryScope = target_element;
|
const queryScope = target_element;
|
||||||
|
|
||||||
// Track open blocks and store their scroll positions
|
// Track open blocks
|
||||||
const openBlocks = new Set();
|
const openBlocks = new Set();
|
||||||
const scrollPositions = {};
|
|
||||||
queryScope.querySelectorAll(".thinking-block").forEach(block => {
|
queryScope.querySelectorAll(".thinking-block").forEach(block => {
|
||||||
const blockId = block.getAttribute("data-block-id");
|
const blockId = block.getAttribute("data-block-id");
|
||||||
|
// If block exists and is open, add to open set
|
||||||
if (blockId && block.hasAttribute("open")) {
|
if (blockId && block.hasAttribute("open")) {
|
||||||
openBlocks.add(blockId);
|
openBlocks.add(blockId);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Store scroll positions for any open blocks
|
||||||
|
const scrollPositions = {};
|
||||||
|
queryScope.querySelectorAll(".thinking-block[open]").forEach(block => {
|
||||||
const content = block.querySelector(".thinking-content");
|
const content = block.querySelector(".thinking-content");
|
||||||
if (content) {
|
const blockId = block.getAttribute("data-block-id");
|
||||||
|
if (content && blockId) {
|
||||||
const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
|
const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
|
||||||
scrollPositions[blockId] = {
|
scrollPositions[blockId] = {
|
||||||
position: content.scrollTop,
|
position: content.scrollTop,
|
||||||
isAtBottom: isAtBottom
|
isAtBottom: isAtBottom
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
morphdom(
|
morphdom(
|
||||||
|
|
@ -350,8 +313,8 @@ function applyMorphdomUpdate(data) {
|
||||||
{
|
{
|
||||||
onBeforeElUpdated: function(fromEl, toEl) {
|
onBeforeElUpdated: function(fromEl, toEl) {
|
||||||
// Preserve code highlighting
|
// Preserve code highlighting
|
||||||
if (fromEl.tagName === "PRE") {
|
if (fromEl.tagName === "PRE" && fromEl.querySelector("code[data-highlighted]")) {
|
||||||
const fromCode = fromEl.querySelector("code[data-highlighted]");
|
const fromCode = fromEl.querySelector("code");
|
||||||
const toCode = toEl.querySelector("code");
|
const toCode = toEl.querySelector("code");
|
||||||
|
|
||||||
if (fromCode && toCode && fromCode.textContent === toCode.textContent) {
|
if (fromCode && toCode && fromCode.textContent === toCode.textContent) {
|
||||||
|
|
@ -396,23 +359,10 @@ function applyMorphdomUpdate(data) {
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
// Syntax highlighting and LaTeX
|
|
||||||
if (window.doSyntaxHighlighting) {
|
|
||||||
window.doSyntaxHighlighting();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auto-scroll runs both before and after padding update.
|
|
||||||
// Before: so content growth isn't hidden by padding absorption.
|
|
||||||
// After: so padding-added space is also scrolled into view.
|
|
||||||
autoScrollToBottom();
|
|
||||||
updateInstructPadding();
|
|
||||||
autoScrollToBottom();
|
|
||||||
|
|
||||||
// Add toggle listeners for new blocks
|
// Add toggle listeners for new blocks
|
||||||
queryScope.querySelectorAll(".thinking-block").forEach(block => {
|
queryScope.querySelectorAll(".thinking-block").forEach(block => {
|
||||||
if (!block._hasToggleListener) {
|
if (!block._hasToggleListener) {
|
||||||
block.addEventListener("toggle", function(e) {
|
block.addEventListener("toggle", function(e) {
|
||||||
const wasScrolled = window.isScrolled;
|
|
||||||
if (this.open) {
|
if (this.open) {
|
||||||
const content = this.querySelector(".thinking-content");
|
const content = this.querySelector(".thinking-content");
|
||||||
if (content) {
|
if (content) {
|
||||||
|
|
@ -421,12 +371,6 @@ function applyMorphdomUpdate(data) {
|
||||||
}, 0);
|
}, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
autoScrollToBottom();
|
|
||||||
updateInstructPadding();
|
|
||||||
autoScrollToBottom();
|
|
||||||
// Restore scroll state so the browser's layout adjustment
|
|
||||||
// from the toggle doesn't disable auto-scroll
|
|
||||||
window.isScrolled = wasScrolled;
|
|
||||||
});
|
});
|
||||||
block._hasToggleListener = true;
|
block._hasToggleListener = true;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
322
js/main.js
322
js/main.js
|
|
@ -2,13 +2,6 @@
|
||||||
// Main
|
// Main
|
||||||
// ------------------------------------------------
|
// ------------------------------------------------
|
||||||
|
|
||||||
// Sync highlight.js theme with the actual Gradio theme
|
|
||||||
var defined_hljs_css = document.body.classList.contains("dark") ? "file/css/highlightjs/github-dark.min.css" : "file/css/highlightjs/github.min.css";
|
|
||||||
var hljsCssElement = document.getElementById("highlight-css");
|
|
||||||
if (hljsCssElement.getAttribute("href") !== defined_hljs_css) {
|
|
||||||
hljsCssElement.setAttribute("href", defined_hljs_css);
|
|
||||||
}
|
|
||||||
|
|
||||||
let main_parent = document.getElementById("chat-tab").parentNode;
|
let main_parent = document.getElementById("chat-tab").parentNode;
|
||||||
let extensions = document.getElementById("extensions");
|
let extensions = document.getElementById("extensions");
|
||||||
|
|
||||||
|
|
@ -50,18 +43,21 @@ document.querySelector(".header_bar").addEventListener("click", function(event)
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
|
|
||||||
// --- Helper functions --- //
|
// --- Helper functions --- //
|
||||||
function isModifiedKeyboardEvent(event) {
|
function isModifiedKeyboardEvent() {
|
||||||
return event instanceof KeyboardEvent &&
|
return (event instanceof KeyboardEvent &&
|
||||||
(event.shiftKey || event.ctrlKey || event.altKey || event.metaKey);
|
event.shiftKey ||
|
||||||
|
event.ctrlKey ||
|
||||||
|
event.altKey ||
|
||||||
|
event.metaKey);
|
||||||
}
|
}
|
||||||
|
|
||||||
function isFocusedOnEditableTextbox(event) {
|
function isFocusedOnEditableTextbox() {
|
||||||
if (event.target.tagName === "INPUT" || event.target.tagName === "TEXTAREA") {
|
if (event.target.tagName === "INPUT" || event.target.tagName === "TEXTAREA") {
|
||||||
return !!event.target.value;
|
return !!event.target.value;
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let previousTabId = "chat-tab-button";
|
||||||
document.addEventListener("keydown", function(event) {
|
document.addEventListener("keydown", function(event) {
|
||||||
// Stop generation on Esc pressed
|
// Stop generation on Esc pressed
|
||||||
if (event.key === "Escape") {
|
if (event.key === "Escape") {
|
||||||
|
|
@ -115,14 +111,14 @@ document.addEventListener("keydown", function(event) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Simple version navigation --- //
|
// --- Simple version navigation --- //
|
||||||
if (!isFocusedOnEditableTextbox(event)) {
|
if (!isFocusedOnEditableTextbox()) {
|
||||||
// Version navigation on Arrow keys (horizontal)
|
// Version navigation on Arrow keys (horizontal)
|
||||||
if (!isModifiedKeyboardEvent(event) && event.key === "ArrowLeft") {
|
if (!isModifiedKeyboardEvent() && event.key === "ArrowLeft") {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
navigateLastAssistantMessage("left");
|
navigateLastAssistantMessage("left");
|
||||||
}
|
}
|
||||||
|
|
||||||
else if (!isModifiedKeyboardEvent(event) && event.key === "ArrowRight") {
|
else if (!isModifiedKeyboardEvent() && event.key === "ArrowRight") {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
if (!navigateLastAssistantMessage("right")) {
|
if (!navigateLastAssistantMessage("right")) {
|
||||||
// If can't navigate right (last version), regenerate
|
// If can't navigate right (last version), regenerate
|
||||||
|
|
@ -149,26 +145,21 @@ targetElement.classList.add("pretty_scrollbar");
|
||||||
targetElement.classList.add("chat-parent");
|
targetElement.classList.add("chat-parent");
|
||||||
window.isScrolled = false;
|
window.isScrolled = false;
|
||||||
let scrollTimeout;
|
let scrollTimeout;
|
||||||
let lastScrollTop = 0;
|
|
||||||
let lastScrollHeight = 0;
|
|
||||||
let lastClientHeight = 0;
|
|
||||||
|
|
||||||
targetElement.addEventListener("scroll", function() {
|
targetElement.addEventListener("scroll", function() {
|
||||||
let diff = targetElement.scrollHeight - targetElement.clientHeight;
|
let diff = targetElement.scrollHeight - targetElement.clientHeight;
|
||||||
let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0;
|
let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff == 0;
|
||||||
|
|
||||||
|
// Add scrolling class to disable hover effects
|
||||||
if (window.isScrolled || !isAtBottomNow) {
|
if (window.isScrolled || !isAtBottomNow) {
|
||||||
targetElement.classList.add("scrolling"); // Disables hover effects during scroll
|
targetElement.classList.add("scrolling");
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isAtBottomNow) {
|
if(isAtBottomNow) {
|
||||||
window.isScrolled = false;
|
window.isScrolled = false;
|
||||||
} else if (targetElement.scrollTop < lastScrollTop && targetElement.scrollHeight >= lastScrollHeight && targetElement.clientHeight <= lastClientHeight) {
|
} else {
|
||||||
window.isScrolled = true;
|
window.isScrolled = true;
|
||||||
}
|
}
|
||||||
lastScrollTop = targetElement.scrollTop;
|
|
||||||
lastScrollHeight = targetElement.scrollHeight;
|
|
||||||
lastClientHeight = targetElement.clientHeight;
|
|
||||||
|
|
||||||
// Clear previous timeout and set new one
|
// Clear previous timeout and set new one
|
||||||
clearTimeout(scrollTimeout);
|
clearTimeout(scrollTimeout);
|
||||||
|
|
@ -179,28 +170,65 @@ targetElement.addEventListener("scroll", function() {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create a MutationObserver instance
|
// Create a MutationObserver instance
|
||||||
const observer = new MutationObserver(function() {
|
const observer = new MutationObserver(function(mutations) {
|
||||||
|
// Check if this is just the scrolling class being toggled
|
||||||
|
const isScrollingClassOnly = mutations.every(mutation =>
|
||||||
|
mutation.type === "attributes" &&
|
||||||
|
mutation.attributeName === "class" &&
|
||||||
|
mutation.target === targetElement
|
||||||
|
);
|
||||||
|
|
||||||
if (targetElement.classList.contains("_generating")) {
|
if (targetElement.classList.contains("_generating")) {
|
||||||
typing.parentNode.classList.add("visible-dots");
|
typing.parentNode.classList.add("visible-dots");
|
||||||
document.getElementById("stop").style.display = "flex";
|
document.getElementById("stop").style.display = "flex";
|
||||||
document.getElementById("Generate").style.display = "none";
|
document.getElementById("Generate").style.display = "none";
|
||||||
// If the user is near the bottom, ensure auto-scroll is enabled
|
|
||||||
// for the new reply. This catches cases where isScrolled was
|
|
||||||
// incorrectly set to true by layout shifts during page load, etc.
|
|
||||||
const diff = targetElement.scrollHeight - targetElement.clientHeight;
|
|
||||||
if (Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0) {
|
|
||||||
window.isScrolled = false;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
typing.parentNode.classList.remove("visible-dots");
|
typing.parentNode.classList.remove("visible-dots");
|
||||||
document.getElementById("stop").style.display = "none";
|
document.getElementById("stop").style.display = "none";
|
||||||
document.getElementById("Generate").style.display = "flex";
|
document.getElementById("Generate").style.display = "flex";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
doSyntaxHighlighting();
|
||||||
|
|
||||||
|
if (!window.isScrolled && !isScrollingClassOnly) {
|
||||||
|
const maxScroll = targetElement.scrollHeight - targetElement.clientHeight;
|
||||||
|
if (maxScroll > 0 && targetElement.scrollTop < maxScroll - 1) {
|
||||||
|
targetElement.scrollTop = maxScroll;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const chatElement = document.getElementById("chat");
|
||||||
|
if (chatElement && chatElement.getAttribute("data-mode") === "instruct") {
|
||||||
|
const messagesContainer = chatElement.querySelector(".messages");
|
||||||
|
const lastChild = messagesContainer?.lastElementChild;
|
||||||
|
const prevSibling = lastChild?.previousElementSibling;
|
||||||
|
if (lastChild && prevSibling) {
|
||||||
|
// Add padding to the messages container to create room for the last message.
|
||||||
|
// The purpose of this is to avoid constant scrolling during streaming in
|
||||||
|
// instruct mode.
|
||||||
|
let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight);
|
||||||
|
|
||||||
|
// Subtract header height when screen width is <= 924px
|
||||||
|
if (window.innerWidth <= 924) {
|
||||||
|
bufferHeight = Math.max(0, bufferHeight - 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
messagesContainer.style.paddingBottom = `${bufferHeight}px`;
|
||||||
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Only watch for attribute changes on targetElement (e.g. _generating class)
|
// Configure the observer to watch for changes in the subtree and attributes
|
||||||
|
const config = {
|
||||||
|
childList: true,
|
||||||
|
subtree: true,
|
||||||
|
characterData: true,
|
||||||
|
attributeOldValue: true,
|
||||||
|
characterDataOldValue: true
|
||||||
|
};
|
||||||
|
|
||||||
// Start observing the target element
|
// Start observing the target element
|
||||||
observer.observe(targetElement, { attributes: true });
|
observer.observe(targetElement, config);
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
// Handle syntax highlighting / LaTeX
|
// Handle syntax highlighting / LaTeX
|
||||||
|
|
@ -215,13 +243,16 @@ function isElementVisibleOnScreen(element) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
window.doSyntaxHighlighting = function() {
|
function doSyntaxHighlighting() {
|
||||||
const messageBodies = document.getElementById("chat").querySelectorAll(".message-body");
|
const messageBodies = document.getElementById("chat").querySelectorAll(".message-body");
|
||||||
|
|
||||||
if (messageBodies.length > 0) {
|
if (messageBodies.length > 0) {
|
||||||
|
observer.disconnect();
|
||||||
|
|
||||||
|
try {
|
||||||
let hasSeenVisible = false;
|
let hasSeenVisible = false;
|
||||||
|
|
||||||
// Go from last message to first so we can early-exit once past visible area
|
// Go from last message to first
|
||||||
for (let i = messageBodies.length - 1; i >= 0; i--) {
|
for (let i = messageBodies.length - 1; i >= 0; i--) {
|
||||||
const messageBody = messageBodies[i];
|
const messageBody = messageBodies[i];
|
||||||
|
|
||||||
|
|
@ -236,8 +267,8 @@ window.doSyntaxHighlighting = function() {
|
||||||
codeBlock.classList.add("pretty_scrollbar");
|
codeBlock.classList.add("pretty_scrollbar");
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Only render math in visible elements
|
||||||
const mathContainers = messageBody.querySelectorAll("p, span, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, figcaption, caption, dd, dt");
|
const mathContainers = messageBody.querySelectorAll("p, span, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, figcaption, caption, dd, dt");
|
||||||
// Only render math in individually visible containers (the outer check is on the message body)
|
|
||||||
mathContainers.forEach(container => {
|
mathContainers.forEach(container => {
|
||||||
if (isElementVisibleOnScreen(container)) {
|
if (isElementVisibleOnScreen(container)) {
|
||||||
renderMathInElement(container, {
|
renderMathInElement(container, {
|
||||||
|
|
@ -256,48 +287,33 @@ window.doSyntaxHighlighting = function() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} finally {
|
||||||
|
observer.observe(targetElement, config);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const doSyntaxHighlighting = window.doSyntaxHighlighting;
|
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
// Add some scrollbars
|
// Add some scrollbars
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
const scrollbarElements = document.querySelectorAll(".add_scrollbar textarea, .add_scrollbar .drag-drop-list");
|
const textareaElements = document.querySelectorAll(".add_scrollbar textarea");
|
||||||
for(let i = 0; i < scrollbarElements.length; i++) {
|
for(i = 0; i < textareaElements.length; i++) {
|
||||||
scrollbarElements[i].classList.remove("scroll-hide");
|
textareaElements[i].classList.remove("scroll-hide");
|
||||||
scrollbarElements[i].classList.add("pretty_scrollbar");
|
textareaElements[i].classList.add("pretty_scrollbar");
|
||||||
scrollbarElements[i].style.resize = "none";
|
textareaElements[i].style.resize = "none";
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//------------------------------------------------
|
|
||||||
// Tools: inject "Refresh list" link into the label
|
|
||||||
//------------------------------------------------
|
|
||||||
const toolsTitle = document.querySelector("#tools-group > [data-testid='block-info']");
|
|
||||||
const toolsInfo = toolsTitle ? toolsTitle.nextElementSibling : null;
|
|
||||||
if (toolsInfo) {
|
|
||||||
const refreshLink = document.createElement("span");
|
|
||||||
refreshLink.textContent = " [Refresh list]";
|
|
||||||
refreshLink.className = "tools-refresh-link";
|
|
||||||
refreshLink.addEventListener("click", function(e) {
|
|
||||||
e.preventDefault();
|
|
||||||
document.querySelector("#tools-refresh-btn").click();
|
|
||||||
});
|
|
||||||
toolsInfo.appendChild(refreshLink);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
// Remove some backgrounds
|
// Remove some backgrounds
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
const noBackgroundelements = document.querySelectorAll(".no-background");
|
const noBackgroundelements = document.querySelectorAll(".no-background");
|
||||||
for(let i = 0; i < noBackgroundelements.length; i++) {
|
for(i = 0; i < noBackgroundelements.length; i++) {
|
||||||
noBackgroundelements[i].parentNode.style.border = "none";
|
noBackgroundelements[i].parentNode.style.border = "none";
|
||||||
noBackgroundelements[i].parentNode.parentNode.parentNode.style.alignItems = "center";
|
noBackgroundelements[i].parentNode.parentNode.parentNode.style.alignItems = "center";
|
||||||
}
|
}
|
||||||
|
|
||||||
const slimDropdownElements = document.querySelectorAll(".slim-dropdown");
|
const slimDropdownElements = document.querySelectorAll(".slim-dropdown");
|
||||||
for (let i = 0; i < slimDropdownElements.length; i++) {
|
for (i = 0; i < slimDropdownElements.length; i++) {
|
||||||
const parentNode = slimDropdownElements[i].parentNode;
|
const parentNode = slimDropdownElements[i].parentNode;
|
||||||
parentNode.style.background = "transparent";
|
parentNode.style.background = "transparent";
|
||||||
parentNode.style.border = "0";
|
parentNode.style.border = "0";
|
||||||
|
|
@ -309,19 +325,18 @@ for (let i = 0; i < slimDropdownElements.length; i++) {
|
||||||
// https://github.com/SillyTavern/SillyTavern/blob/6c8bd06308c69d51e2eb174541792a870a83d2d6/public/script.js
|
// https://github.com/SillyTavern/SillyTavern/blob/6c8bd06308c69d51e2eb174541792a870a83d2d6/public/script.js
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
var buttonsInChat = document.querySelectorAll("#chat-tab #chat-buttons button, #chat-tab #chat-buttons #show-controls");
|
var buttonsInChat = document.querySelectorAll("#chat-tab #chat-buttons button, #chat-tab #chat-buttons #show-controls");
|
||||||
var hoverContainer = document.getElementById("gr-hover-container");
|
|
||||||
var button = document.getElementById("hover-element-button");
|
var button = document.getElementById("hover-element-button");
|
||||||
var menu = document.getElementById("hover-menu");
|
var menu = document.getElementById("hover-menu");
|
||||||
var istouchscreen = (navigator.maxTouchPoints > 0) || "ontouchstart" in document.documentElement;
|
var istouchscreen = (navigator.maxTouchPoints > 0) || "ontouchstart" in document.documentElement;
|
||||||
|
|
||||||
function showMenu() {
|
function showMenu() {
|
||||||
menu.style.display = "flex";
|
menu.style.display = "flex"; // Show the menu
|
||||||
}
|
}
|
||||||
|
|
||||||
function hideMenu() {
|
function hideMenu() {
|
||||||
menu.style.display = "none";
|
menu.style.display = "none"; // Hide the menu
|
||||||
if (!istouchscreen) {
|
if (!istouchscreen) {
|
||||||
document.querySelector("#chat-input textarea").focus();
|
document.querySelector("#chat-input textarea").focus(); // Focus on the chat input
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -330,6 +345,7 @@ if (buttonsInChat.length > 0) {
|
||||||
const thisButton = buttonsInChat[i];
|
const thisButton = buttonsInChat[i];
|
||||||
menu.appendChild(thisButton);
|
menu.appendChild(thisButton);
|
||||||
|
|
||||||
|
// Only apply transformations to button elements
|
||||||
if (thisButton.tagName.toLowerCase() === "button") {
|
if (thisButton.tagName.toLowerCase() === "button") {
|
||||||
thisButton.addEventListener("click", () => {
|
thisButton.addEventListener("click", () => {
|
||||||
hideMenu();
|
hideMenu();
|
||||||
|
|
@ -339,6 +355,7 @@ if (buttonsInChat.length > 0) {
|
||||||
const matches = buttonText.match(/(\(.*?\))/);
|
const matches = buttonText.match(/(\(.*?\))/);
|
||||||
|
|
||||||
if (matches && matches.length > 1) {
|
if (matches && matches.length > 1) {
|
||||||
|
// Apply the transparent-substring class to the matched substring
|
||||||
const substring = matches[1];
|
const substring = matches[1];
|
||||||
const newText = buttonText.replace(substring, ` <span class="transparent-substring">${substring.slice(1, -1)}</span>`);
|
const newText = buttonText.replace(substring, ` <span class="transparent-substring">${substring.slice(1, -1)}</span>`);
|
||||||
thisButton.innerHTML = newText;
|
thisButton.innerHTML = newText;
|
||||||
|
|
@ -347,19 +364,16 @@ if (buttonsInChat.length > 0) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var menuInteracting = false;
|
function isMouseOverButtonOrMenu() {
|
||||||
|
return menu.matches(":hover") || button.matches(":hover");
|
||||||
|
}
|
||||||
|
|
||||||
hoverContainer.addEventListener("mouseenter", function () {
|
button.addEventListener("mouseenter", function () {
|
||||||
if (!istouchscreen) {
|
if (!istouchscreen) {
|
||||||
showMenu();
|
showMenu();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
hoverContainer.addEventListener("mousedown", function () {
|
|
||||||
menuInteracting = true;
|
|
||||||
setTimeout(function () { menuInteracting = false; }, 300);
|
|
||||||
});
|
|
||||||
|
|
||||||
button.addEventListener("click", function () {
|
button.addEventListener("click", function () {
|
||||||
if (menu.style.display === "flex") {
|
if (menu.style.display === "flex") {
|
||||||
hideMenu();
|
hideMenu();
|
||||||
|
|
@ -369,26 +383,36 @@ button.addEventListener("click", function () {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
hoverContainer.addEventListener("mouseleave", function () {
|
// Add event listener for mouseleave on the button
|
||||||
if (!istouchscreen) {
|
button.addEventListener("mouseleave", function () {
|
||||||
|
// Delay to prevent menu hiding when the mouse leaves the button into the menu
|
||||||
setTimeout(function () {
|
setTimeout(function () {
|
||||||
if (!hoverContainer.matches(":hover") && !menu.matches(":hover")) {
|
if (!isMouseOverButtonOrMenu()) {
|
||||||
hideMenu();
|
hideMenu();
|
||||||
}
|
}
|
||||||
}, 50);
|
}, 100);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add event listener for mouseleave on the menu
|
||||||
|
menu.addEventListener("mouseleave", function () {
|
||||||
|
// Delay to prevent menu hide when the mouse leaves the menu into the button
|
||||||
|
setTimeout(function () {
|
||||||
|
if (!isMouseOverButtonOrMenu()) {
|
||||||
|
hideMenu();
|
||||||
}
|
}
|
||||||
|
}, 100);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add event listener for click anywhere in the document
|
// Add event listener for click anywhere in the document
|
||||||
document.addEventListener("click", function (event) {
|
document.addEventListener("click", function (event) {
|
||||||
|
const target = event.target;
|
||||||
|
|
||||||
// Check if the click is outside the button/menu and the menu is visible
|
// Check if the click is outside the button/menu and the menu is visible
|
||||||
if (!menuInteracting && !event.target.closest("#gr-hover-container") && menu.style.display === "flex") {
|
if (!isMouseOverButtonOrMenu() && menu.style.display === "flex") {
|
||||||
hideMenu();
|
hideMenu();
|
||||||
}
|
}
|
||||||
|
|
||||||
const target = event.target;
|
if (event.target.classList.contains("pfp_character")) {
|
||||||
|
|
||||||
if (target.classList.contains("pfp_character")) {
|
|
||||||
toggleBigPicture();
|
toggleBigPicture();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -418,19 +442,27 @@ document.getElementById("chat-input-row").classList.add("chat-input-positioned")
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
const chatTextArea = document.getElementById("chat-input").querySelector("textarea");
|
const chatTextArea = document.getElementById("chat-input").querySelector("textarea");
|
||||||
|
|
||||||
function focusOnVisible(element) {
|
function respondToChatInputVisibility(element, callback) {
|
||||||
var observer = new IntersectionObserver((entries) => {
|
var options = {
|
||||||
|
root: document.documentElement,
|
||||||
|
};
|
||||||
|
|
||||||
|
var observer = new IntersectionObserver((entries, observer) => {
|
||||||
entries.forEach(entry => {
|
entries.forEach(entry => {
|
||||||
if (entry.intersectionRatio > 0) {
|
callback(entry.intersectionRatio > 0);
|
||||||
element.focus();
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}, { root: document.documentElement });
|
}, options);
|
||||||
|
|
||||||
observer.observe(element);
|
observer.observe(element);
|
||||||
}
|
}
|
||||||
|
|
||||||
focusOnVisible(chatTextArea);
|
function handleChatInputVisibilityChange(isVisible) {
|
||||||
|
if (isVisible) {
|
||||||
|
chatTextArea.focus();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respondToChatInputVisibility(chatTextArea, handleChatInputVisibilityChange);
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
// Show enlarged character picture when the profile
|
// Show enlarged character picture when the profile
|
||||||
|
|
@ -440,7 +472,8 @@ let bigPictureVisible = false;
|
||||||
|
|
||||||
function addBigPicture() {
|
function addBigPicture() {
|
||||||
var imgElement = document.createElement("img");
|
var imgElement = document.createElement("img");
|
||||||
imgElement.src = getProfilePictureUrl();
|
var timestamp = new Date().getTime();
|
||||||
|
imgElement.src = "/file/user_data/cache/pfp_character.png?time=" + timestamp;
|
||||||
imgElement.classList.add("bigProfilePicture");
|
imgElement.classList.add("bigProfilePicture");
|
||||||
imgElement.addEventListener("load", function () {
|
imgElement.addEventListener("load", function () {
|
||||||
this.style.visibility = "visible";
|
this.style.visibility = "visible";
|
||||||
|
|
@ -454,8 +487,9 @@ function addBigPicture() {
|
||||||
}
|
}
|
||||||
|
|
||||||
function deleteBigPicture() {
|
function deleteBigPicture() {
|
||||||
document.querySelectorAll(".bigProfilePicture").forEach(function (element) {
|
var bigProfilePictures = document.querySelectorAll(".bigProfilePicture");
|
||||||
element.remove();
|
bigProfilePictures.forEach(function (element) {
|
||||||
|
element.parentNode.removeChild(element);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -469,11 +503,44 @@ function toggleBigPicture() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//------------------------------------------------
|
||||||
|
// Handle the chat input box growth
|
||||||
|
//------------------------------------------------
|
||||||
|
|
||||||
|
// Cache DOM elements
|
||||||
|
const chatContainer = document.getElementById("chat").parentNode.parentNode.parentNode;
|
||||||
|
const chatInput = document.querySelector("#chat-input textarea");
|
||||||
|
|
||||||
|
// Variables to store current dimensions
|
||||||
|
let currentChatInputHeight = chatInput.clientHeight;
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
// Focus on the rename text area when it becomes visible
|
// Focus on the rename text area when it becomes visible
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
const renameTextArea = document.getElementById("rename-row").querySelector("textarea");
|
const renameTextArea = document.getElementById("rename-row").querySelector("textarea");
|
||||||
focusOnVisible(renameTextArea);
|
|
||||||
|
function respondToRenameVisibility(element, callback) {
|
||||||
|
var options = {
|
||||||
|
root: document.documentElement,
|
||||||
|
};
|
||||||
|
|
||||||
|
var observer = new IntersectionObserver((entries, observer) => {
|
||||||
|
entries.forEach(entry => {
|
||||||
|
callback(entry.intersectionRatio > 0);
|
||||||
|
});
|
||||||
|
}, options);
|
||||||
|
|
||||||
|
observer.observe(element);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
function handleVisibilityChange(isVisible) {
|
||||||
|
if (isVisible) {
|
||||||
|
renameTextArea.focus();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respondToRenameVisibility(renameTextArea, handleVisibilityChange);
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
// Adjust the chat tab margin if no extension UI
|
// Adjust the chat tab margin if no extension UI
|
||||||
|
|
@ -494,38 +561,6 @@ document.querySelectorAll(".focus-on-chat-input").forEach(element => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
//------------------------------------------------
|
|
||||||
// "New chat" hover menu with incognito option
|
|
||||||
//------------------------------------------------
|
|
||||||
|
|
||||||
(function() {
|
|
||||||
const newChatBtn = document.getElementById("new-chat-btn");
|
|
||||||
|
|
||||||
const wrapper = document.createElement("div");
|
|
||||||
wrapper.id = "new-chat-wrapper";
|
|
||||||
newChatBtn.replaceWith(wrapper);
|
|
||||||
wrapper.appendChild(newChatBtn);
|
|
||||||
|
|
||||||
const arrow = document.createElement("span");
|
|
||||||
arrow.className = "new-chat-arrow";
|
|
||||||
arrow.textContent = "\u25BE";
|
|
||||||
|
|
||||||
const menu = document.createElement("div");
|
|
||||||
menu.className = "new-chat-menu";
|
|
||||||
const option = document.createElement("div");
|
|
||||||
option.className = "new-chat-menu-item";
|
|
||||||
option.textContent = "Incognito chat";
|
|
||||||
menu.appendChild(option);
|
|
||||||
|
|
||||||
arrow.appendChild(menu);
|
|
||||||
wrapper.appendChild(arrow);
|
|
||||||
|
|
||||||
option.addEventListener("click", function(e) {
|
|
||||||
e.stopPropagation();
|
|
||||||
document.querySelector("#incognito-chat-btn").click();
|
|
||||||
});
|
|
||||||
})();
|
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
// Fix a border around the "past chats" menu
|
// Fix a border around the "past chats" menu
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
|
|
@ -679,21 +714,21 @@ function handleIndividualSidebarClose(event) {
|
||||||
|
|
||||||
// Close navigation bar if click is outside and it is open
|
// Close navigation bar if click is outside and it is open
|
||||||
if (!headerBar.contains(target) && !headerBar.classList.contains("sidebar-hidden")) {
|
if (!headerBar.contains(target) && !headerBar.classList.contains("sidebar-hidden")) {
|
||||||
toggleSidebar(headerBar, navigationToggle);
|
toggleSidebar(headerBar, navigationToggle, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close past chats row if click is outside and it is open
|
// Close past chats row if click is outside and it is open
|
||||||
if (!pastChatsRow.contains(target) && !pastChatsRow.classList.contains("sidebar-hidden")) {
|
if (!pastChatsRow.contains(target) && !pastChatsRow.classList.contains("sidebar-hidden")) {
|
||||||
toggleSidebar(pastChatsRow, pastChatsToggle);
|
toggleSidebar(pastChatsRow, pastChatsToggle, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close chat controls row if click is outside and it is open
|
// Close chat controls row if click is outside and it is open
|
||||||
if (!chatControlsRow.contains(target) && !chatControlsRow.classList.contains("sidebar-hidden")) {
|
if (!chatControlsRow.contains(target) && !chatControlsRow.classList.contains("sidebar-hidden")) {
|
||||||
toggleSidebar(chatControlsRow, chatControlsToggle);
|
toggleSidebar(chatControlsRow, chatControlsToggle, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function toggleSidebar(sidebar, toggle) {
|
function toggleSidebar(sidebar, toggle, forceClose = false) {
|
||||||
const isCurrentlyHidden = sidebar.classList.contains("sidebar-hidden");
|
const isCurrentlyHidden = sidebar.classList.contains("sidebar-hidden");
|
||||||
const shouldClose = !isCurrentlyHidden;
|
const shouldClose = !isCurrentlyHidden;
|
||||||
|
|
||||||
|
|
@ -718,6 +753,11 @@ function toggleSidebar(sidebar, toggle) {
|
||||||
toggle.classList.toggle("chat-controls-open", !shouldClose);
|
toggle.classList.toggle("chat-controls-open", !shouldClose);
|
||||||
toggle.innerHTML = shouldClose ? leftArrowSVG : rightArrowSVG;
|
toggle.innerHTML = shouldClose ? leftArrowSVG : rightArrowSVG;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mobile handling
|
||||||
|
if (isMobile()) {
|
||||||
|
sidebar.classList.toggle("sidebar-shown", !shouldClose);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to check if the device is mobile
|
// Function to check if the device is mobile
|
||||||
|
|
@ -777,17 +817,17 @@ pastChatsToggle.addEventListener("click", () => {
|
||||||
const isCurrentlyOpen = !pastChatsRow.classList.contains("sidebar-hidden");
|
const isCurrentlyOpen = !pastChatsRow.classList.contains("sidebar-hidden");
|
||||||
toggleSidebar(pastChatsRow, pastChatsToggle);
|
toggleSidebar(pastChatsRow, pastChatsToggle);
|
||||||
|
|
||||||
// On desktop, sync both sidebars together
|
// On desktop, open/close both sidebars at the same time
|
||||||
if (!isMobile()) {
|
if (!isMobile()) {
|
||||||
if (isCurrentlyOpen) {
|
if (isCurrentlyOpen) {
|
||||||
// If we just closed the left sidebar, also close the right sidebar
|
// If we just closed the left sidebar, also close the right sidebar
|
||||||
if (!chatControlsRow.classList.contains("sidebar-hidden")) {
|
if (!chatControlsRow.classList.contains("sidebar-hidden")) {
|
||||||
toggleSidebar(chatControlsRow, chatControlsToggle);
|
toggleSidebar(chatControlsRow, chatControlsToggle, true);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// If we just opened the left sidebar, also open the right sidebar
|
// If we just opened the left sidebar, also open the right sidebar
|
||||||
if (chatControlsRow.classList.contains("sidebar-hidden")) {
|
if (chatControlsRow.classList.contains("sidebar-hidden")) {
|
||||||
toggleSidebar(chatControlsRow, chatControlsToggle);
|
toggleSidebar(chatControlsRow, chatControlsToggle, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -797,17 +837,17 @@ chatControlsToggle.addEventListener("click", () => {
|
||||||
const isCurrentlyOpen = !chatControlsRow.classList.contains("sidebar-hidden");
|
const isCurrentlyOpen = !chatControlsRow.classList.contains("sidebar-hidden");
|
||||||
toggleSidebar(chatControlsRow, chatControlsToggle);
|
toggleSidebar(chatControlsRow, chatControlsToggle);
|
||||||
|
|
||||||
// On desktop, sync both sidebars together
|
// On desktop, open/close both sidebars at the same time
|
||||||
if (!isMobile()) {
|
if (!isMobile()) {
|
||||||
if (isCurrentlyOpen) {
|
if (isCurrentlyOpen) {
|
||||||
// If we just closed the right sidebar, also close the left sidebar
|
// If we just closed the right sidebar, also close the left sidebar
|
||||||
if (!pastChatsRow.classList.contains("sidebar-hidden")) {
|
if (!pastChatsRow.classList.contains("sidebar-hidden")) {
|
||||||
toggleSidebar(pastChatsRow, pastChatsToggle);
|
toggleSidebar(pastChatsRow, pastChatsToggle, true);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// If we just opened the right sidebar, also open the left sidebar
|
// If we just opened the right sidebar, also open the left sidebar
|
||||||
if (pastChatsRow.classList.contains("sidebar-hidden")) {
|
if (pastChatsRow.classList.contains("sidebar-hidden")) {
|
||||||
toggleSidebar(pastChatsRow, pastChatsToggle);
|
toggleSidebar(pastChatsRow, pastChatsToggle, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -827,7 +867,7 @@ if (isMobile()) {
|
||||||
const textarea = document.querySelector("#chat-input textarea");
|
const textarea = document.querySelector("#chat-input textarea");
|
||||||
|
|
||||||
if (textarea) {
|
if (textarea) {
|
||||||
// Force textarea height recalculation by simulating content change
|
// Simulate adding and removing a newline
|
||||||
textarea.value += "\n";
|
textarea.value += "\n";
|
||||||
textarea.dispatchEvent(new Event("input", { bubbles: true }));
|
textarea.dispatchEvent(new Event("input", { bubbles: true }));
|
||||||
textarea.value = textarea.value.slice(0, -1);
|
textarea.value = textarea.value.slice(0, -1);
|
||||||
|
|
@ -1050,13 +1090,15 @@ document.fonts.addEventListener("loadingdone", (event) => {
|
||||||
const currentHeight = chatInputRow.offsetHeight;
|
const currentHeight = chatInputRow.offsetHeight;
|
||||||
const heightDifference = currentHeight - originalHeight;
|
const heightDifference = currentHeight - originalHeight;
|
||||||
chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`;
|
chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`;
|
||||||
if (!window.isScrolled) {
|
|
||||||
chatParent.scrollTop = chatParent.scrollHeight - chatParent.clientHeight;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Watch for size changes that affect height
|
// Watch for changes that might affect height
|
||||||
new ResizeObserver(updateMargin).observe(chatInputRow);
|
const observer = new MutationObserver(updateMargin);
|
||||||
|
observer.observe(chatInputRow, {
|
||||||
|
childList: true,
|
||||||
|
subtree: true,
|
||||||
|
attributes: true
|
||||||
|
});
|
||||||
|
|
||||||
// Also listen for window resize
|
// Also listen for window resize
|
||||||
window.addEventListener("resize", updateMargin);
|
window.addEventListener("resize", updateMargin);
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
// Functions for downloading JSON files
|
// Functions for downloading JSON files
|
||||||
function getCurrentTimestamp() {
|
function getCurrentTimestamp() {
|
||||||
const now = new Date();
|
const now = new Date();
|
||||||
const timezoneOffset = now.getTimezoneOffset() * 60000; // Convert minutes to milliseconds
|
const timezoneOffset = now.getTimezoneOffset() * 60000; // Convert to milliseconds
|
||||||
const localTime = new Date(now.getTime() - timezoneOffset);
|
const localTime = new Date(now.getTime() - timezoneOffset);
|
||||||
return localTime.toISOString().replace(/[-:]/g, "").slice(0, 15);
|
const formattedTimestamp = localTime.toISOString().replace(/[-:]/g, "").slice(0, 15);
|
||||||
|
return formattedTimestamp;
|
||||||
}
|
}
|
||||||
|
|
||||||
function saveFile(contents, filename) {
|
function saveFile(contents, filename) {
|
||||||
|
|
@ -17,18 +18,23 @@ function saveFile(contents, filename) {
|
||||||
}
|
}
|
||||||
|
|
||||||
function saveHistory(history, character, mode) {
|
function saveHistory(history, character, mode) {
|
||||||
let path;
|
let path = null;
|
||||||
|
|
||||||
if (["chat", "chat-instruct"].includes(mode) && character && character.trim() !== "") {
|
if (["chat", "chat-instruct"].includes(mode) && character && character.trim() !== "") {
|
||||||
path = `history_${character}_${getCurrentTimestamp()}.json`;
|
path = `history_${character}_${getCurrentTimestamp()}.json`;
|
||||||
} else {
|
} else {
|
||||||
path = `history_${mode || "unknown"}_${getCurrentTimestamp()}.json`;
|
try {
|
||||||
|
path = `history_${mode}_${getCurrentTimestamp()}.json`;
|
||||||
|
} catch (error) {
|
||||||
|
path = `history_${getCurrentTimestamp()}.json`;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
saveFile(history, path);
|
saveFile(history, path);
|
||||||
}
|
}
|
||||||
|
|
||||||
function saveSession(session) {
|
function saveSession(session) {
|
||||||
const path = `session_${getCurrentTimestamp()}.json`;
|
let path = null;
|
||||||
|
|
||||||
|
path = `session_${getCurrentTimestamp()}.json`;
|
||||||
saveFile(session, path);
|
saveFile(session, path);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
|
const chatParent = document.querySelector(".chat-parent");
|
||||||
|
|
||||||
function toggle_controls(value) {
|
function toggle_controls(value) {
|
||||||
const navToggle = document.getElementById("navigation-toggle");
|
|
||||||
const pastChatsToggle = document.getElementById("past-chats-toggle");
|
|
||||||
const extensions = document.querySelector("#extensions");
|
const extensions = document.querySelector("#extensions");
|
||||||
const galleryExtension = document.getElementById("gallery-extension");
|
|
||||||
|
|
||||||
if (value) {
|
if (value) {
|
||||||
// SHOW MODE: Click toggles to show hidden sidebars
|
// SHOW MODE: Click toggles to show hidden sidebars
|
||||||
|
const navToggle = document.getElementById("navigation-toggle");
|
||||||
|
const pastChatsToggle = document.getElementById("past-chats-toggle");
|
||||||
|
|
||||||
if (navToggle && document.querySelector(".header_bar")?.classList.contains("sidebar-hidden")) {
|
if (navToggle && document.querySelector(".header_bar")?.classList.contains("sidebar-hidden")) {
|
||||||
navToggle.click();
|
navToggle.click();
|
||||||
}
|
}
|
||||||
|
|
@ -17,11 +19,17 @@ function toggle_controls(value) {
|
||||||
if (extensions) {
|
if (extensions) {
|
||||||
extensions.style.display = "inherit";
|
extensions.style.display = "inherit";
|
||||||
}
|
}
|
||||||
if (galleryExtension) {
|
|
||||||
galleryExtension.style.display = "block";
|
let gallery_element = document.getElementById("gallery-extension");
|
||||||
|
if (gallery_element) {
|
||||||
|
gallery_element.style.display = "block";
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// HIDE MODE: Click toggles to hide visible sidebars
|
// HIDE MODE: Click toggles to hide visible sidebars
|
||||||
|
const navToggle = document.getElementById("navigation-toggle");
|
||||||
|
const pastChatsToggle = document.getElementById("past-chats-toggle");
|
||||||
|
|
||||||
if (navToggle && !document.querySelector(".header_bar")?.classList.contains("sidebar-hidden")) {
|
if (navToggle && !document.querySelector(".header_bar")?.classList.contains("sidebar-hidden")) {
|
||||||
navToggle.click();
|
navToggle.click();
|
||||||
}
|
}
|
||||||
|
|
@ -33,8 +41,5 @@ function toggle_controls(value) {
|
||||||
if (extensions) {
|
if (extensions) {
|
||||||
extensions.style.display = "none";
|
extensions.style.display = "none";
|
||||||
}
|
}
|
||||||
if (galleryExtension) {
|
|
||||||
galleryExtension.style.display = "none";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,17 @@ function scrollToTop() {
|
||||||
window.scrollTo({ top: 0 });
|
window.scrollTo({ top: 0 });
|
||||||
}
|
}
|
||||||
|
|
||||||
function findButtonsByText(buttonText, container = document) {
|
function findButtonsByText(buttonText) {
|
||||||
return Array.from(container.getElementsByTagName("button"))
|
const buttons = document.getElementsByTagName("button");
|
||||||
.filter(btn => btn.textContent.trim() === buttonText);
|
const matchingButtons = [];
|
||||||
|
|
||||||
|
for (let i = 0; i < buttons.length; i++) {
|
||||||
|
if (buttons[i].textContent.trim() === buttonText) {
|
||||||
|
matchingButtons.push(buttons[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return matchingButtons;
|
||||||
}
|
}
|
||||||
|
|
||||||
function switch_to_chat() {
|
function switch_to_chat() {
|
||||||
|
|
@ -31,9 +39,13 @@ function switch_to_character() {
|
||||||
|
|
||||||
function switch_to_image_ai_generate() {
|
function switch_to_image_ai_generate() {
|
||||||
const container = document.querySelector("#image-ai-tab");
|
const container = document.querySelector("#image-ai-tab");
|
||||||
const generateBtn = findButtonsByText("Generate", container)[0];
|
const buttons = container.getElementsByTagName("button");
|
||||||
if (generateBtn) {
|
|
||||||
generateBtn.click();
|
for (let i = 0; i < buttons.length; i++) {
|
||||||
|
if (buttons[i].textContent.trim() === "Generate") {
|
||||||
|
buttons[i].click();
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
scrollToTop();
|
scrollToTop();
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
function updateBigPicture() {
|
function updateBigPicture() {
|
||||||
var existingElement = document.querySelector(".bigProfilePicture");
|
var existingElement = document.querySelector(".bigProfilePicture");
|
||||||
if (existingElement) {
|
if (existingElement) {
|
||||||
existingElement.src = getProfilePictureUrl();
|
var timestamp = new Date().getTime();
|
||||||
|
existingElement.src = "/file/user_data/cache/pfp_character.png?time=" + timestamp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,468 +0,0 @@
|
||||||
import json
|
|
||||||
import time
|
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
|
|
||||||
def convert_request(body: dict) -> dict:
|
|
||||||
"""Transform Anthropic Messages API body into the dict that chat_completions_common expects."""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# System message
|
|
||||||
system = body.get('system')
|
|
||||||
if system:
|
|
||||||
if isinstance(system, list):
|
|
||||||
# List of content blocks like [{"type":"text","text":"..."}]
|
|
||||||
text_parts = [block.get('text', '') for block in system if isinstance(block, dict) and block.get('type') == 'text']
|
|
||||||
system_text = '\n'.join(text_parts)
|
|
||||||
else:
|
|
||||||
system_text = str(system)
|
|
||||||
if system_text:
|
|
||||||
messages.append({"role": "system", "content": system_text})
|
|
||||||
|
|
||||||
# Convert messages
|
|
||||||
for msg in body.get('messages', []):
|
|
||||||
role = msg.get('role')
|
|
||||||
content = msg.get('content')
|
|
||||||
|
|
||||||
if isinstance(content, str):
|
|
||||||
messages.append({"role": role, "content": content})
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not isinstance(content, list):
|
|
||||||
messages.append({"role": role, "content": str(content) if content else ""})
|
|
||||||
continue
|
|
||||||
|
|
||||||
if role == 'assistant':
|
|
||||||
# Split into text content, tool_calls, and skip thinking blocks
|
|
||||||
text_parts = []
|
|
||||||
tool_calls = []
|
|
||||||
for block in content:
|
|
||||||
btype = block.get('type')
|
|
||||||
if btype == 'text':
|
|
||||||
text_parts.append(block.get('text', ''))
|
|
||||||
elif btype == 'tool_use':
|
|
||||||
tool_calls.append({
|
|
||||||
"id": block.get('id', ''),
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": block.get('name', ''),
|
|
||||||
"arguments": json.dumps(block.get('input', {}))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
elif btype == 'thinking':
|
|
||||||
pass # Strip thinking blocks
|
|
||||||
|
|
||||||
assistant_msg = {"role": "assistant", "content": '\n'.join(text_parts) if text_parts else ""}
|
|
||||||
if tool_calls:
|
|
||||||
assistant_msg["tool_calls"] = tool_calls
|
|
||||||
messages.append(assistant_msg)
|
|
||||||
|
|
||||||
elif role == 'user':
|
|
||||||
# Handle tool_result blocks and regular content
|
|
||||||
regular_parts = []
|
|
||||||
for block in content:
|
|
||||||
btype = block.get('type')
|
|
||||||
if btype == 'tool_result':
|
|
||||||
# Emit any accumulated regular content first
|
|
||||||
if regular_parts:
|
|
||||||
if len(regular_parts) == 1 and regular_parts[0].get('type') == 'text':
|
|
||||||
messages.append({"role": "user", "content": regular_parts[0]['text']})
|
|
||||||
else:
|
|
||||||
messages.append({"role": "user", "content": regular_parts})
|
|
||||||
regular_parts = []
|
|
||||||
# Convert tool_result to OpenAI tool message
|
|
||||||
tool_content = block.get('content', '')
|
|
||||||
if isinstance(tool_content, list):
|
|
||||||
tool_content = '\n'.join(
|
|
||||||
b.get('text', '') for b in tool_content
|
|
||||||
if isinstance(b, dict) and b.get('type') == 'text'
|
|
||||||
)
|
|
||||||
messages.append({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": block.get('tool_use_id', ''),
|
|
||||||
"content": str(tool_content)
|
|
||||||
})
|
|
||||||
elif btype == 'text':
|
|
||||||
regular_parts.append({"type": "text", "text": block.get('text', '')})
|
|
||||||
elif btype == 'image':
|
|
||||||
source = block.get('source', {})
|
|
||||||
if source.get('type') == 'base64':
|
|
||||||
media_type = source.get('media_type', 'image/png')
|
|
||||||
data = source.get('data', '')
|
|
||||||
regular_parts.append({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": f"data:{media_type};base64,{data}"}
|
|
||||||
})
|
|
||||||
elif btype == 'thinking':
|
|
||||||
pass # Strip thinking blocks
|
|
||||||
|
|
||||||
if regular_parts:
|
|
||||||
if len(regular_parts) == 1 and regular_parts[0].get('type') == 'text':
|
|
||||||
messages.append({"role": "user", "content": regular_parts[0]['text']})
|
|
||||||
else:
|
|
||||||
messages.append({"role": "user", "content": regular_parts})
|
|
||||||
else:
|
|
||||||
messages.append({"role": role, "content": str(content)})
|
|
||||||
|
|
||||||
# Start with all fields from the original body (includes GenerationOptions defaults)
|
|
||||||
result = dict(body)
|
|
||||||
|
|
||||||
# Remove Anthropic-specific fields that don't map directly
|
|
||||||
for key in ('system', 'stop_sequences', 'tools', 'tool_choice', 'thinking', 'metadata'):
|
|
||||||
result.pop(key, None)
|
|
||||||
|
|
||||||
# Set converted fields
|
|
||||||
result['messages'] = messages
|
|
||||||
result['max_tokens'] = body.get('max_tokens', 4096)
|
|
||||||
result['stream'] = body.get('stream', False)
|
|
||||||
result['mode'] = 'instruct'
|
|
||||||
|
|
||||||
# Ensure ChatCompletionRequestParams defaults are present
|
|
||||||
result.setdefault('continue_', False)
|
|
||||||
result.setdefault('instruction_template', None)
|
|
||||||
result.setdefault('instruction_template_str', None)
|
|
||||||
result.setdefault('character', None)
|
|
||||||
result.setdefault('bot_name', None)
|
|
||||||
result.setdefault('context', None)
|
|
||||||
result.setdefault('greeting', None)
|
|
||||||
result.setdefault('user_name', None)
|
|
||||||
result.setdefault('user_bio', None)
|
|
||||||
result.setdefault('chat_template_str', None)
|
|
||||||
result.setdefault('chat_instruct_command', 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>')
|
|
||||||
result.setdefault('frequency_penalty', None)
|
|
||||||
result.setdefault('presence_penalty', None)
|
|
||||||
result.setdefault('logit_bias', None)
|
|
||||||
result.setdefault('logprobs', None)
|
|
||||||
result.setdefault('top_logprobs', None)
|
|
||||||
result.setdefault('n', 1)
|
|
||||||
result.setdefault('model', None)
|
|
||||||
result.setdefault('functions', None)
|
|
||||||
result.setdefault('function_call', None)
|
|
||||||
result.setdefault('stream_options', None)
|
|
||||||
result.setdefault('user', None)
|
|
||||||
result.setdefault('stop', None)
|
|
||||||
result.setdefault('tool_choice', None)
|
|
||||||
|
|
||||||
# Always request usage in streaming so the usage-only chunk triggers
|
|
||||||
# the deferred message_delta/message_stop with accurate output_tokens
|
|
||||||
if body.get('stream', False):
|
|
||||||
result['stream_options'] = {'include_usage': True}
|
|
||||||
|
|
||||||
# Map stop_sequences -> stop
|
|
||||||
if body.get('stop_sequences'):
|
|
||||||
result['stop'] = body['stop_sequences']
|
|
||||||
|
|
||||||
# Tools
|
|
||||||
if body.get('tools'):
|
|
||||||
result['tools'] = [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": t.get('name', ''),
|
|
||||||
"description": t.get('description', ''),
|
|
||||||
"parameters": t.get('input_schema', {"type": "object", "properties": {}})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for t in body['tools']
|
|
||||||
]
|
|
||||||
|
|
||||||
# Tool choice
|
|
||||||
tc = body.get('tool_choice')
|
|
||||||
if tc and isinstance(tc, dict):
|
|
||||||
tc_type = tc.get('type')
|
|
||||||
if tc_type == 'auto':
|
|
||||||
result['tool_choice'] = 'auto'
|
|
||||||
elif tc_type == 'any':
|
|
||||||
result['tool_choice'] = 'required'
|
|
||||||
elif tc_type == 'tool':
|
|
||||||
result['tool_choice'] = {"type": "function", "function": {"name": tc.get('name', '')}}
|
|
||||||
elif tc_type == 'none':
|
|
||||||
result['tool_choice'] = 'none'
|
|
||||||
else:
|
|
||||||
result.setdefault('tool_choice', None)
|
|
||||||
|
|
||||||
# Thinking
|
|
||||||
thinking = body.get('thinking')
|
|
||||||
if thinking and isinstance(thinking, dict) and thinking.get('type') in ('enabled', 'adaptive'):
|
|
||||||
result['enable_thinking'] = True
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
_FINISH_REASON_MAP = {
|
|
||||||
"stop": "end_turn",
|
|
||||||
"length": "max_tokens",
|
|
||||||
"tool_calls": "tool_use",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def build_response(openai_resp: dict, model: str) -> dict:
|
|
||||||
"""Transform OpenAI chat completion response dict into Anthropic Messages format."""
|
|
||||||
resp_id = openai_resp.get('id', 'msg_unknown')
|
|
||||||
if resp_id.startswith('chatcmpl-'):
|
|
||||||
resp_id = 'msg_' + resp_id[9:]
|
|
||||||
|
|
||||||
choice = openai_resp.get('choices', [{}])[0]
|
|
||||||
message = choice.get('message', {})
|
|
||||||
|
|
||||||
content = []
|
|
||||||
|
|
||||||
# Reasoning/thinking content
|
|
||||||
reasoning = message.get('reasoning_content')
|
|
||||||
if reasoning:
|
|
||||||
content.append({"type": "thinking", "thinking": reasoning, "signature": ""})
|
|
||||||
|
|
||||||
# Text content
|
|
||||||
text = message.get('content')
|
|
||||||
if text:
|
|
||||||
content.append({"type": "text", "text": text})
|
|
||||||
|
|
||||||
# Tool calls
|
|
||||||
tool_calls = message.get('tool_calls')
|
|
||||||
if tool_calls:
|
|
||||||
for tc in tool_calls:
|
|
||||||
func = tc.get('function', {})
|
|
||||||
try:
|
|
||||||
input_data = json.loads(func.get('arguments', '{}'))
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
input_data = {}
|
|
||||||
content.append({
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": tc.get('id', ''),
|
|
||||||
"name": func.get('name', ''),
|
|
||||||
"input": input_data
|
|
||||||
})
|
|
||||||
|
|
||||||
finish_reason = choice.get('finish_reason', 'stop')
|
|
||||||
stop_reason = _FINISH_REASON_MAP.get(finish_reason, 'end_turn')
|
|
||||||
|
|
||||||
usage = openai_resp.get('usage', {})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"content": content,
|
|
||||||
"model": model,
|
|
||||||
"stop_reason": stop_reason,
|
|
||||||
"stop_sequence": None,
|
|
||||||
"usage": {
|
|
||||||
"input_tokens": usage.get('prompt_tokens', 0),
|
|
||||||
"output_tokens": usage.get('completion_tokens', 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class StreamConverter:
|
|
||||||
"""Stateful converter: processes one OpenAI chunk at a time, yields Anthropic SSE events.
|
|
||||||
|
|
||||||
When include_usage is enabled in the OpenAI request, the final chunk with
|
|
||||||
finish_reason has usage=None, followed by a separate usage-only chunk
|
|
||||||
(choices=[], usage={...}). We defer emitting message_delta and message_stop
|
|
||||||
until we receive that usage chunk so output_tokens is accurate.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model: str):
|
|
||||||
self.model = model
|
|
||||||
self.msg_id = "msg_%d" % int(time.time() * 1000000000)
|
|
||||||
self.block_index = 0
|
|
||||||
self.in_thinking = False
|
|
||||||
self.in_text = False
|
|
||||||
self.input_tokens = 0
|
|
||||||
self.output_tokens = 0
|
|
||||||
self.tool_calls_accum = {}
|
|
||||||
self.stop_reason = "end_turn"
|
|
||||||
self._pending_finish = False # True after we've seen finish_reason
|
|
||||||
|
|
||||||
def process_chunk(self, chunk: dict) -> list[dict]:
|
|
||||||
"""Process a single OpenAI streaming chunk; return list of Anthropic SSE event dicts."""
|
|
||||||
events = []
|
|
||||||
choices = chunk.get('choices', [])
|
|
||||||
usage = chunk.get('usage')
|
|
||||||
|
|
||||||
if usage:
|
|
||||||
self.input_tokens = usage.get('prompt_tokens', self.input_tokens)
|
|
||||||
self.output_tokens = usage.get('completion_tokens', self.output_tokens)
|
|
||||||
|
|
||||||
# Usage-only chunk (choices=[]) arrives after the finish chunk
|
|
||||||
if not choices:
|
|
||||||
if self._pending_finish:
|
|
||||||
events.extend(self.finish())
|
|
||||||
return events
|
|
||||||
|
|
||||||
choice = choices[0]
|
|
||||||
delta = choice.get('delta', {})
|
|
||||||
finish_reason = choice.get('finish_reason')
|
|
||||||
|
|
||||||
# First chunk with role
|
|
||||||
if 'role' in delta:
|
|
||||||
events.append({
|
|
||||||
"event": "message_start",
|
|
||||||
"data": json.dumps({
|
|
||||||
"type": "message_start",
|
|
||||||
"message": {
|
|
||||||
"id": self.msg_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [],
|
|
||||||
"model": self.model,
|
|
||||||
"stop_reason": None,
|
|
||||||
"stop_sequence": None,
|
|
||||||
"usage": {"input_tokens": self.input_tokens, "output_tokens": 0}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
events.append({"event": "ping", "data": json.dumps({"type": "ping"})})
|
|
||||||
return events
|
|
||||||
|
|
||||||
# Reasoning content
|
|
||||||
reasoning_content = delta.get('reasoning_content')
|
|
||||||
if reasoning_content:
|
|
||||||
if not self.in_thinking:
|
|
||||||
self.in_thinking = True
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_start",
|
|
||||||
"data": json.dumps({
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": self.block_index,
|
|
||||||
"content_block": {"type": "thinking", "thinking": "", "signature": ""}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_delta",
|
|
||||||
"data": json.dumps({
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": self.block_index,
|
|
||||||
"delta": {"type": "thinking_delta", "thinking": reasoning_content}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
return events
|
|
||||||
|
|
||||||
# Text content
|
|
||||||
text_content = delta.get('content')
|
|
||||||
if text_content:
|
|
||||||
if self.in_thinking:
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_stop",
|
|
||||||
"data": json.dumps({"type": "content_block_stop", "index": self.block_index})
|
|
||||||
})
|
|
||||||
self.in_thinking = False
|
|
||||||
self.block_index += 1
|
|
||||||
|
|
||||||
if not self.in_text:
|
|
||||||
self.in_text = True
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_start",
|
|
||||||
"data": json.dumps({
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": self.block_index,
|
|
||||||
"content_block": {"type": "text", "text": ""}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_delta",
|
|
||||||
"data": json.dumps({
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": self.block_index,
|
|
||||||
"delta": {"type": "text_delta", "text": text_content}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
return events
|
|
||||||
|
|
||||||
# Tool calls in delta
|
|
||||||
chunk_tool_calls = delta.get('tool_calls')
|
|
||||||
if chunk_tool_calls:
|
|
||||||
for tc in chunk_tool_calls:
|
|
||||||
tc_id = tc.get('id', '')
|
|
||||||
tc_idx = tc.get('index', 0)
|
|
||||||
func = tc.get('function', {})
|
|
||||||
if tc_id:
|
|
||||||
self.tool_calls_accum[tc_idx] = {
|
|
||||||
"id": tc_id,
|
|
||||||
"name": func.get('name', ''),
|
|
||||||
"arguments": func.get('arguments', '')
|
|
||||||
}
|
|
||||||
elif tc_idx in self.tool_calls_accum:
|
|
||||||
self.tool_calls_accum[tc_idx]["arguments"] += func.get('arguments', '')
|
|
||||||
|
|
||||||
# Final chunk — close open content blocks, defer message_delta/stop for usage
|
|
||||||
if finish_reason is not None:
|
|
||||||
self.stop_reason = _FINISH_REASON_MAP.get(finish_reason, 'end_turn')
|
|
||||||
|
|
||||||
if self.in_thinking:
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_stop",
|
|
||||||
"data": json.dumps({"type": "content_block_stop", "index": self.block_index})
|
|
||||||
})
|
|
||||||
self.in_thinking = False
|
|
||||||
self.block_index += 1
|
|
||||||
|
|
||||||
if self.in_text:
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_stop",
|
|
||||||
"data": json.dumps({"type": "content_block_stop", "index": self.block_index})
|
|
||||||
})
|
|
||||||
self.in_text = False
|
|
||||||
self.block_index += 1
|
|
||||||
|
|
||||||
for tc_idx in sorted(self.tool_calls_accum.keys()):
|
|
||||||
tc = self.tool_calls_accum[tc_idx]
|
|
||||||
arguments_str = tc["arguments"] or "{}"
|
|
||||||
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_start",
|
|
||||||
"data": json.dumps({
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": self.block_index,
|
|
||||||
"content_block": {
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": tc["id"],
|
|
||||||
"name": tc["name"],
|
|
||||||
"input": {}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
# Emit the full input as a single input_json_delta so SDK
|
|
||||||
# clients that reconstruct from deltas get the correct data
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_delta",
|
|
||||||
"data": json.dumps({
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": self.block_index,
|
|
||||||
"delta": {
|
|
||||||
"type": "input_json_delta",
|
|
||||||
"partial_json": arguments_str
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
events.append({
|
|
||||||
"event": "content_block_stop",
|
|
||||||
"data": json.dumps({"type": "content_block_stop", "index": self.block_index})
|
|
||||||
})
|
|
||||||
self.block_index += 1
|
|
||||||
|
|
||||||
# Defer message_delta/stop — usage chunk may follow
|
|
||||||
self._pending_finish = True
|
|
||||||
|
|
||||||
return events
|
|
||||||
|
|
||||||
def finish(self) -> list[dict]:
|
|
||||||
"""Emit deferred message_delta and message_stop. Safe to call multiple times."""
|
|
||||||
if not self._pending_finish:
|
|
||||||
return []
|
|
||||||
self._pending_finish = False
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"event": "message_delta",
|
|
||||||
"data": json.dumps({
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": {"stop_reason": self.stop_reason, "stop_sequence": None},
|
|
||||||
"usage": {"input_tokens": self.input_tokens, "output_tokens": self.output_tokens}
|
|
||||||
})
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"event": "message_stop",
|
|
||||||
"data": json.dumps({"type": "message_stop"})
|
|
||||||
}
|
|
||||||
]
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,53 +0,0 @@
|
||||||
import base64
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def float_list_to_base64(float_array: np.ndarray) -> str:
|
|
||||||
# Convert the list to a float32 array that the OpenAPI client expects
|
|
||||||
# float_array = np.array(float_list, dtype="float32")
|
|
||||||
|
|
||||||
# Get raw bytes
|
|
||||||
bytes_array = float_array.tobytes()
|
|
||||||
|
|
||||||
# Encode bytes into base64
|
|
||||||
encoded_bytes = base64.b64encode(bytes_array)
|
|
||||||
|
|
||||||
# Turn raw base64 encoded bytes into ASCII
|
|
||||||
ascii_string = encoded_bytes.decode('ascii')
|
|
||||||
return ascii_string
|
|
||||||
|
|
||||||
|
|
||||||
def debug_msg(*args, **kwargs):
|
|
||||||
if int(os.environ.get("OPENEDAI_DEBUG", 0)):
|
|
||||||
print(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
|
||||||
try:
|
|
||||||
from flask_cloudflared import _run_cloudflared
|
|
||||||
except ImportError:
|
|
||||||
print('You should install flask_cloudflared manually')
|
|
||||||
raise Exception(
|
|
||||||
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
|
|
||||||
|
|
||||||
for _ in range(max_attempts):
|
|
||||||
try:
|
|
||||||
if tunnel_id is not None:
|
|
||||||
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
|
|
||||||
else:
|
|
||||||
public_url = _run_cloudflared(port, port + 1)
|
|
||||||
|
|
||||||
if on_start:
|
|
||||||
on_start(public_url)
|
|
||||||
|
|
||||||
return
|
|
||||||
except Exception:
|
|
||||||
traceback.print_exc()
|
|
||||||
time.sleep(3)
|
|
||||||
|
|
||||||
raise Exception('Could not start cloudflared.')
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
|
import traceback
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.logging_colors import logger
|
|
||||||
|
|
||||||
|
|
||||||
class StopNowException(Exception):
|
class StopNowException(Exception):
|
||||||
|
|
@ -34,11 +34,12 @@ class Iteratorize:
|
||||||
|
|
||||||
def gentask():
|
def gentask():
|
||||||
try:
|
try:
|
||||||
ret = self.mfunc(callback=_callback, *self.args, **self.kwargs)
|
ret = self.mfunc(callback=_callback, *args, **self.kwargs)
|
||||||
except StopNowException:
|
except StopNowException:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed in generation callback")
|
traceback.print_exc()
|
||||||
|
pass
|
||||||
|
|
||||||
self.q.put(self.sentinel)
|
self.q.put(self.sentinel)
|
||||||
if self.c_callback:
|
if self.c_callback:
|
||||||
|
|
|
||||||
807
modules/chat.py
807
modules/chat.py
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
||||||
import math
|
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, Tuple
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
|
@ -9,7 +9,6 @@ import torch
|
||||||
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
|
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
|
||||||
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
|
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
|
||||||
from exllamav3.generator import Job
|
from exllamav3.generator import Job
|
||||||
from exllamav3.generator.filter import Filter
|
|
||||||
from exllamav3.generator.sampler import (
|
from exllamav3.generator.sampler import (
|
||||||
CustomSampler,
|
CustomSampler,
|
||||||
SS_AdaptiveP,
|
SS_AdaptiveP,
|
||||||
|
|
@ -33,30 +32,8 @@ from modules.text_generation import get_max_prompt_length
|
||||||
try:
|
try:
|
||||||
import flash_attn
|
import flash_attn
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning('Failed to load flash-attention due to the following error:', exc_info=True)
|
logger.warning('Failed to load flash-attention due to the following error:\n')
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
class LogitBiasFilter(Filter):
|
|
||||||
"""Filter subclass that applies a static additive logit bias mask."""
|
|
||||||
|
|
||||||
def __init__(self, tokenizer, logit_bias_dict):
|
|
||||||
super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False)
|
|
||||||
self.logit_bias_dict = logit_bias_dict
|
|
||||||
self._mask = None
|
|
||||||
|
|
||||||
def reset(self): pass
|
|
||||||
def accept_token(self, token): pass
|
|
||||||
def is_completed(self): return False
|
|
||||||
def use_background_worker(self): return False
|
|
||||||
|
|
||||||
def get_next_logit_mask(self):
|
|
||||||
if self._mask is None:
|
|
||||||
self._mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
|
|
||||||
for token_id_str, bias in self.logit_bias_dict.items():
|
|
||||||
token_id = int(token_id_str)
|
|
||||||
if 0 <= token_id < self.vocab_size:
|
|
||||||
self._mask[0, token_id] = bias
|
|
||||||
return self._mask
|
|
||||||
|
|
||||||
|
|
||||||
class ConcurrentGenerator:
|
class ConcurrentGenerator:
|
||||||
|
|
@ -76,16 +53,7 @@ class ConcurrentGenerator:
|
||||||
if not self.job_queues:
|
if not self.job_queues:
|
||||||
self.has_jobs.clear()
|
self.has_jobs.clear()
|
||||||
continue
|
continue
|
||||||
try:
|
|
||||||
results = self.generator.iterate()
|
results = self.generator.iterate()
|
||||||
except Exception:
|
|
||||||
logger.exception("Exception in ConcurrentGenerator iterate loop")
|
|
||||||
for q in self.job_queues.values():
|
|
||||||
q.put(None)
|
|
||||||
self.job_queues.clear()
|
|
||||||
self.generator.clear_queue()
|
|
||||||
self.has_jobs.clear()
|
|
||||||
continue
|
|
||||||
for result in results:
|
for result in results:
|
||||||
job = result["job"]
|
job = result["job"]
|
||||||
q = self.job_queues.get(job)
|
q = self.job_queues.get(job)
|
||||||
|
|
@ -121,10 +89,6 @@ class Exllamav3Model:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
|
||||||
return torch.device(0)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, path_to_model):
|
def from_pretrained(cls, path_to_model):
|
||||||
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
||||||
|
|
@ -185,21 +149,8 @@ class Exllamav3Model:
|
||||||
load_params['tensor_p'] = True
|
load_params['tensor_p'] = True
|
||||||
load_params['tp_backend'] = shared.args.tp_backend
|
load_params['tp_backend'] = shared.args.tp_backend
|
||||||
|
|
||||||
# Load vision and draft before the main model so autosplit
|
model.load(**load_params)
|
||||||
# accounts for their VRAM usage.
|
tokenizer = Tokenizer.from_config(config)
|
||||||
|
|
||||||
# Load vision model component (ExLlamaV3 native)
|
|
||||||
vision_model = None
|
|
||||||
if "vision_config" in config.config_dict:
|
|
||||||
logger.info("Vision component detected in model config. Attempting to load...")
|
|
||||||
try:
|
|
||||||
vision_model = Model.from_config(config, component="vision")
|
|
||||||
vision_model.load(progressbar=True)
|
|
||||||
logger.info("Vision model loaded successfully.")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
|
|
||||||
else:
|
|
||||||
logger.info("No vision component in model config. Skipping multimodal setup.")
|
|
||||||
|
|
||||||
# Initialize draft model for speculative decoding
|
# Initialize draft model for speculative decoding
|
||||||
draft_model = None
|
draft_model = None
|
||||||
|
|
@ -215,8 +166,23 @@ class Exllamav3Model:
|
||||||
logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
|
logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
|
||||||
else:
|
else:
|
||||||
draft_config = Config.from_directory(str(draft_path))
|
draft_config = Config.from_directory(str(draft_path))
|
||||||
|
|
||||||
|
# Set context size for draft model with 256-multiple validation
|
||||||
|
if shared.args.ctx_size_draft > 0:
|
||||||
|
draft_max_tokens = shared.args.ctx_size_draft
|
||||||
|
else:
|
||||||
|
draft_max_tokens = shared.args.ctx_size
|
||||||
|
|
||||||
|
# Validate draft model context size is a multiple of 256
|
||||||
|
if draft_max_tokens % 256 != 0:
|
||||||
|
adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
|
||||||
|
logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
|
||||||
|
draft_max_tokens = adjusted_draft_tokens
|
||||||
|
|
||||||
|
draft_config.max_seq_len = draft_max_tokens
|
||||||
|
|
||||||
draft_model = Model.from_config(draft_config)
|
draft_model = Model.from_config(draft_config)
|
||||||
draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
|
draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
|
||||||
|
|
||||||
draft_load_params = {'progressbar': True}
|
draft_load_params = {'progressbar': True}
|
||||||
if split:
|
if split:
|
||||||
|
|
@ -225,9 +191,18 @@ class Exllamav3Model:
|
||||||
draft_model.load(**draft_load_params)
|
draft_model.load(**draft_load_params)
|
||||||
logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
|
logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
|
||||||
|
|
||||||
# Load main model last
|
# Load vision model component (ExLlamaV3 native)
|
||||||
model.load(**load_params)
|
vision_model = None
|
||||||
tokenizer = Tokenizer.from_config(config)
|
if "vision_config" in config.config_dict:
|
||||||
|
logger.info("Vision component detected in model config. Attempting to load...")
|
||||||
|
try:
|
||||||
|
vision_model = Model.from_config(config, component="vision")
|
||||||
|
vision_model.load(progressbar=True)
|
||||||
|
logger.info("Vision model loaded successfully.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
|
||||||
|
else:
|
||||||
|
logger.info("No vision component in model config. Skipping multimodal setup.")
|
||||||
|
|
||||||
generator = Generator(
|
generator = Generator(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -410,31 +385,11 @@ class Exllamav3Model:
|
||||||
else:
|
else:
|
||||||
max_new_tokens = state['max_new_tokens']
|
max_new_tokens = state['max_new_tokens']
|
||||||
|
|
||||||
# Use full EOS token list from config (may contain multiple IDs)
|
# Get stop conditions
|
||||||
stop_conditions = []
|
stop_conditions = []
|
||||||
if not state['ban_eos_token']:
|
if not state['ban_eos_token']:
|
||||||
for eos_id in self.config.eos_token_id_list:
|
if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
|
||||||
if eos_id is not None:
|
stop_conditions.append(self.tokenizer.eos_token_id)
|
||||||
stop_conditions.append(eos_id)
|
|
||||||
|
|
||||||
# Build filters for logit_bias (OpenAI API)
|
|
||||||
filters = []
|
|
||||||
logit_bias = state.get('logit_bias')
|
|
||||||
if logit_bias:
|
|
||||||
filters.append(LogitBiasFilter(self.tokenizer, logit_bias))
|
|
||||||
|
|
||||||
# Suppress EOS tokens via logit bias so they are never sampled
|
|
||||||
if state['ban_eos_token']:
|
|
||||||
eos_bias = {}
|
|
||||||
for eos_id in self.config.eos_token_id_list:
|
|
||||||
if eos_id is not None:
|
|
||||||
eos_bias[str(eos_id)] = float('-inf')
|
|
||||||
if eos_bias:
|
|
||||||
filters.append(LogitBiasFilter(self.tokenizer, eos_bias))
|
|
||||||
|
|
||||||
# Logprobs support (OpenAI API)
|
|
||||||
logprobs = state.get('logprobs', 0) or 0
|
|
||||||
return_top_tokens = logprobs if logprobs > 0 else 0
|
|
||||||
|
|
||||||
seed = state.get('seed', -1)
|
seed = state.get('seed', -1)
|
||||||
job = Job(
|
job = Job(
|
||||||
|
|
@ -445,15 +400,11 @@ class Exllamav3Model:
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
seed=seed if seed >= 0 else None,
|
seed=seed if seed >= 0 else None,
|
||||||
stop_conditions=stop_conditions if stop_conditions else None,
|
stop_conditions=stop_conditions if stop_conditions else None,
|
||||||
filters=filters if filters else None,
|
|
||||||
return_top_tokens=return_top_tokens,
|
|
||||||
return_probs=return_top_tokens > 0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stream generation
|
# Stream generation
|
||||||
response_text = ""
|
response_text = ""
|
||||||
stop_event = state.get('stop_event')
|
stop_event = state.get('stop_event')
|
||||||
self.last_completion_probabilities = []
|
|
||||||
|
|
||||||
result_queue = self.parallel_generator.submit(job)
|
result_queue = self.parallel_generator.submit(job)
|
||||||
try:
|
try:
|
||||||
|
|
@ -465,61 +416,14 @@ class Exllamav3Model:
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
if result is None or result.get("eos"):
|
if result is None or result.get("eos"):
|
||||||
# Capture logprobs from the final eos result too
|
|
||||||
if result is not None and return_top_tokens > 0:
|
|
||||||
self._capture_logprobs(result)
|
|
||||||
break
|
break
|
||||||
chunk = result.get("text", "")
|
chunk = result.get("text", "")
|
||||||
|
|
||||||
# Capture logprobs from streaming results
|
|
||||||
if return_top_tokens > 0:
|
|
||||||
self._capture_logprobs(result)
|
|
||||||
|
|
||||||
if chunk:
|
if chunk:
|
||||||
response_text += chunk
|
response_text += chunk
|
||||||
yield response_text
|
yield response_text
|
||||||
finally:
|
finally:
|
||||||
self.parallel_generator.cancel(job)
|
self.parallel_generator.cancel(job)
|
||||||
|
|
||||||
def _capture_logprobs(self, result):
|
|
||||||
"""Convert ExLlamav3 top-k token data to the shared logprobs format."""
|
|
||||||
top_k_tokens = result.get("top_k_tokens")
|
|
||||||
top_k_probs = result.get("top_k_probs")
|
|
||||||
if top_k_tokens is None or top_k_probs is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
id_to_piece = self.tokenizer.get_id_to_piece_list(True)
|
|
||||||
sampled_ids = result.get("token_ids") # (batch, seq_len) - actually sampled tokens
|
|
||||||
sampled_probs = result.get("token_probs") # (batch, seq_len) - their probabilities
|
|
||||||
|
|
||||||
def _piece(tid):
|
|
||||||
s = id_to_piece[tid] if tid < len(id_to_piece) else f"<{tid}>"
|
|
||||||
return s.replace('\u2581', ' ')
|
|
||||||
|
|
||||||
def _logprob(prob):
|
|
||||||
return math.log(prob) if prob > 0 else float("-inf")
|
|
||||||
|
|
||||||
# top_k_tokens shape: (batch, seq_len, k), top_k_probs same
|
|
||||||
for seq_idx in range(top_k_tokens.shape[1]):
|
|
||||||
entry = {"top_logprobs": []}
|
|
||||||
for k_idx in range(top_k_tokens.shape[2]):
|
|
||||||
token_id = top_k_tokens[0, seq_idx, k_idx].item()
|
|
||||||
prob = top_k_probs[0, seq_idx, k_idx].item()
|
|
||||||
entry["top_logprobs"].append({"token": _piece(token_id), "logprob": _logprob(prob)})
|
|
||||||
|
|
||||||
# Record the actually sampled token at the entry level so
|
|
||||||
# format_completion_logprobs uses it instead of top_logprobs[0]
|
|
||||||
# (they differ with non-greedy sampling).
|
|
||||||
if sampled_ids is not None:
|
|
||||||
sid = sampled_ids[0, seq_idx].item()
|
|
||||||
entry["token"] = _piece(sid)
|
|
||||||
if sampled_probs is not None:
|
|
||||||
entry["logprob"] = _logprob(sampled_probs[0, seq_idx].item())
|
|
||||||
else:
|
|
||||||
entry["logprob"] = None
|
|
||||||
|
|
||||||
self.last_completion_probabilities.append(entry)
|
|
||||||
|
|
||||||
def generate(self, prompt, state):
|
def generate(self, prompt, state):
|
||||||
output = ""
|
output = ""
|
||||||
for chunk in self.generate_with_streaming(prompt, state):
|
for chunk in self.generate_with_streaming(prompt, state):
|
||||||
|
|
@ -527,31 +431,42 @@ class Exllamav3Model:
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_prompt_logits(self, input_ids):
|
|
||||||
"""Return logits for all positions via a single no-cache forward pass.
|
|
||||||
|
|
||||||
Used by prompt logprobs computation. Returns (1, seq_len, vocab) on CPU in float32.
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
|
|
||||||
input_ids_tensor = input_ids_tensor.view(1, -1).cpu()
|
|
||||||
with torch.no_grad():
|
|
||||||
return self.model.forward(
|
|
||||||
input_ids=input_ids_tensor,
|
|
||||||
params={"attn_mode": "flash_attn_nc"}
|
|
||||||
).cpu().float()
|
|
||||||
|
|
||||||
def get_logits(self, token_ids, **kwargs):
|
def get_logits(self, token_ids, **kwargs):
|
||||||
"""
|
"""
|
||||||
Process a batch of token_ids and return the logits for the last token.
|
Process a batch of token_ids and return the logits for the last token.
|
||||||
Uses flash_attn_nc (no cache) for correct results with recurrent models.
|
This will reset and overwrite the model's cache.
|
||||||
"""
|
"""
|
||||||
logits = self.model.forward(
|
# Initialize a single params dictionary that will be updated in-place
|
||||||
input_ids=token_ids,
|
params = {
|
||||||
params={"attn_mode": "flash_attn_nc"}
|
"cache": self.cache,
|
||||||
|
"reconstruct": False,
|
||||||
|
"attn_mode": "flash_attn",
|
||||||
|
"batch_shape": (1, self.max_tokens),
|
||||||
|
"past_len": 0
|
||||||
|
}
|
||||||
|
params.update(kwargs)
|
||||||
|
|
||||||
|
# Process prefix tokens to fill the cache and generate recurrent state
|
||||||
|
if token_ids.shape[-1] > 1:
|
||||||
|
prefix_ids = token_ids[:, :-1]
|
||||||
|
|
||||||
|
# This forward call updates the 'params' dict with the recurrent state
|
||||||
|
self.model.forward(
|
||||||
|
input_ids=prefix_ids,
|
||||||
|
params=params
|
||||||
)
|
)
|
||||||
|
|
||||||
return logits[:, -1:, :].float().cpu()
|
# Update past_len for the next call
|
||||||
|
params["past_len"] = prefix_ids.shape[-1]
|
||||||
|
|
||||||
|
# Process the last token, now using the state-filled 'params' dict
|
||||||
|
last_token_ids = token_ids[:, -1:]
|
||||||
|
logits = self.model.forward(
|
||||||
|
input_ids=last_token_ids,
|
||||||
|
params=params
|
||||||
|
)
|
||||||
|
|
||||||
|
return logits.float().cpu()
|
||||||
|
|
||||||
def encode(self, string, **kwargs):
|
def encode(self, string, **kwargs):
|
||||||
add_bos = kwargs.pop('add_bos', True)
|
add_bos = kwargs.pop('add_bos', True)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
|
@ -20,15 +21,13 @@ from modules.logging_colors import logger
|
||||||
try:
|
try:
|
||||||
import flash_attn
|
import flash_attn
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning('Failed to load flash-attention due to the following error:', exc_info=True)
|
logger.warning('Failed to load flash-attention due to the following error:\n')
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
||||||
def __init__(self, model_dir):
|
def __init__(self, model_dir):
|
||||||
hf_config = PretrainedConfig.from_pretrained(model_dir)
|
hf_config = PretrainedConfig.from_pretrained(model_dir)
|
||||||
# Ensure text_config is a proper object, not a dict (fixes qwen3_5_moe + transformers compat)
|
|
||||||
if isinstance(getattr(hf_config, 'text_config', None), dict):
|
|
||||||
hf_config.text_config = PretrainedConfig(**hf_config.text_config)
|
|
||||||
super().__init__(hf_config)
|
super().__init__(hf_config)
|
||||||
|
|
||||||
exl3_config = Config.from_directory(model_dir)
|
exl3_config = Config.from_directory(model_dir)
|
||||||
|
|
@ -202,11 +201,26 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
||||||
}
|
}
|
||||||
).to(input_ids.device).float()
|
).to(input_ids.device).float()
|
||||||
else:
|
else:
|
||||||
# Labels path: single pass without cache for correct logits
|
# When processing with labels, handle as a complete sequence
|
||||||
logits = self.ex_model.forward(
|
# Process in chunks if the number of tokens is large
|
||||||
input_ids=seq_tensor.view(1, -1),
|
tokens_to_process = seq_tensor
|
||||||
params={"attn_mode": "flash_attn_nc"}
|
all_logits = None
|
||||||
).float().cpu()
|
|
||||||
|
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
|
||||||
|
chunk = tokens_to_process[i:i + max_chunk_size]
|
||||||
|
chunk_logits = self.ex_model.forward(
|
||||||
|
input_ids=chunk.view(1, -1),
|
||||||
|
params={
|
||||||
|
"attn_mode": "flash_attn_nc",
|
||||||
|
}
|
||||||
|
).float()
|
||||||
|
|
||||||
|
if all_logits is None:
|
||||||
|
all_logits = chunk_logits
|
||||||
|
else:
|
||||||
|
all_logits = torch.cat([all_logits, chunk_logits], dim=1)
|
||||||
|
|
||||||
|
logits = all_logits
|
||||||
|
|
||||||
if is_negative:
|
if is_negative:
|
||||||
self.past_seq_negative = seq_tensor
|
self.past_seq_negative = seq_tensor
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import sys
|
import sys
|
||||||
|
import traceback
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
@ -32,6 +35,7 @@ def load_extensions():
|
||||||
if name not in available_extensions:
|
if name not in available_extensions:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if name != 'api':
|
||||||
logger.info(f'Loading the extension "{name}"')
|
logger.info(f'Loading the extension "{name}"')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -73,7 +77,8 @@ def load_extensions():
|
||||||
raise
|
raise
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f'Failed to load the extension "{name}".')
|
logger.error(f'Failed to load the extension "{name}".')
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
# This iterator returns the extensions in the order specified in the command-line
|
# This iterator returns the extensions in the order specified in the command-line
|
||||||
|
|
@ -191,23 +196,24 @@ def _apply_custom_generate_reply():
|
||||||
|
|
||||||
|
|
||||||
def _apply_custom_css():
|
def _apply_custom_css():
|
||||||
return ''.join(
|
all_css = ''
|
||||||
getattr(extension, 'custom_css')()
|
for extension, _ in iterator():
|
||||||
for extension, _ in iterator()
|
if hasattr(extension, 'custom_css'):
|
||||||
if hasattr(extension, 'custom_css')
|
all_css += getattr(extension, 'custom_css')()
|
||||||
)
|
|
||||||
|
return all_css
|
||||||
|
|
||||||
|
|
||||||
def _apply_custom_js():
|
def _apply_custom_js():
|
||||||
return ''.join(
|
all_js = ''
|
||||||
getattr(extension, 'custom_js')()
|
for extension, _ in iterator():
|
||||||
for extension, _ in iterator()
|
if hasattr(extension, 'custom_js'):
|
||||||
if hasattr(extension, 'custom_js')
|
all_js += getattr(extension, 'custom_js')()
|
||||||
)
|
|
||||||
|
return all_js
|
||||||
|
|
||||||
|
|
||||||
def create_extensions_block():
|
def create_extensions_block():
|
||||||
import gradio as gr
|
|
||||||
to_display = []
|
to_display = []
|
||||||
for extension, name in iterator():
|
for extension, name in iterator():
|
||||||
if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
|
if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
|
||||||
|
|
@ -222,7 +228,6 @@ def create_extensions_block():
|
||||||
|
|
||||||
|
|
||||||
def create_extensions_tabs():
|
def create_extensions_tabs():
|
||||||
import gradio as gr
|
|
||||||
for extension, name in iterator():
|
for extension, name in iterator():
|
||||||
if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
|
if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
|
||||||
display_name = getattr(extension, 'params', {}).get('display_name', name)
|
display_name = getattr(extension, 'params', {}).get('display_name', name)
|
||||||
|
|
|
||||||
|
|
@ -10,17 +10,9 @@ import markdown
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.reasoning import extract_reasoning
|
|
||||||
from modules.sane_markdown_lists import SaneListExtension
|
from modules.sane_markdown_lists import SaneListExtension
|
||||||
from modules.utils import get_available_chat_styles
|
from modules.utils import get_available_chat_styles
|
||||||
|
|
||||||
# Pre-compiled regex for protecting markdown-sensitive characters inside LaTeX.
|
|
||||||
# Covers $$...$$, \[...\], \(...\), and inline $...$ (when content contains \\).
|
|
||||||
_LATEX_PATTERN = re.compile(
|
|
||||||
r'((?:^|[\r\n\s])\$\$[^`]*?\$\$)|\\\[(.*?)\\\]|\\\((.*?)\\\)|(?<!\$)\$(?!\$)([^\$\n]*\\\\[^\$\n]*?)\$(?!\$)',
|
|
||||||
re.DOTALL
|
|
||||||
)
|
|
||||||
|
|
||||||
# This is to store the paths to the thumbnails of the profile pictures
|
# This is to store the paths to the thumbnails of the profile pictures
|
||||||
image_cache = {}
|
image_cache = {}
|
||||||
|
|
||||||
|
|
@ -116,41 +108,69 @@ def replace_blockquote(m):
|
||||||
return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '')
|
return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '')
|
||||||
|
|
||||||
|
|
||||||
|
# Thinking block format definitions: (start_tag, end_tag, content_start_tag)
|
||||||
|
# Use None for start_tag to match from beginning (end-only formats should be listed last)
|
||||||
|
THINKING_FORMATS = [
|
||||||
|
('<think>', '</think>', None),
|
||||||
|
('<|channel|>analysis<|message|>', '<|end|>', '<|start|>assistant<|channel|>final<|message|>'),
|
||||||
|
('<seed:think>', '</seed:think>', None),
|
||||||
|
('<|think|>', '<|end|>', '<|content|>'), # Solar Open
|
||||||
|
('Thinking Process:', '</think>', None), # Qwen3.5 verbose thinking outside tags
|
||||||
|
(None, '</think>', None), # End-only variant (e.g., Qwen3-next)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def extract_thinking_block(string):
|
def extract_thinking_block(string):
|
||||||
"""Extract thinking blocks from the beginning of an HTML-escaped string."""
|
"""Extract thinking blocks from the beginning of a string."""
|
||||||
return extract_reasoning(string, html_escaped=True)
|
if not string:
|
||||||
|
return None, string
|
||||||
|
|
||||||
|
for start_tag, end_tag, content_tag in THINKING_FORMATS:
|
||||||
|
end_esc = html.escape(end_tag)
|
||||||
|
content_esc = html.escape(content_tag) if content_tag else None
|
||||||
|
|
||||||
|
if start_tag is None:
|
||||||
|
# End-only format: require end tag, start from beginning
|
||||||
|
end_pos = string.find(end_esc)
|
||||||
|
if end_pos == -1:
|
||||||
|
continue
|
||||||
|
thought_start = 0
|
||||||
|
else:
|
||||||
|
# Normal format: require start tag
|
||||||
|
start_esc = html.escape(start_tag)
|
||||||
|
start_pos = string.find(start_esc)
|
||||||
|
if start_pos == -1:
|
||||||
|
continue
|
||||||
|
thought_start = start_pos + len(start_esc)
|
||||||
|
end_pos = string.find(end_esc, thought_start)
|
||||||
|
|
||||||
|
if end_pos == -1:
|
||||||
|
# End tag missing - check if content tag can serve as fallback
|
||||||
|
if content_esc:
|
||||||
|
content_pos = string.find(content_esc, thought_start)
|
||||||
|
if content_pos != -1:
|
||||||
|
thought_end = content_pos
|
||||||
|
content_start = content_pos + len(content_esc)
|
||||||
|
else:
|
||||||
|
thought_end = len(string)
|
||||||
|
content_start = len(string)
|
||||||
|
else:
|
||||||
|
thought_end = len(string)
|
||||||
|
content_start = len(string)
|
||||||
|
else:
|
||||||
|
thought_end = end_pos
|
||||||
|
if content_esc:
|
||||||
|
content_pos = string.find(content_esc, end_pos)
|
||||||
|
content_start = content_pos + len(content_esc) if content_pos != -1 else end_pos + len(end_esc)
|
||||||
|
else:
|
||||||
|
content_start = end_pos + len(end_esc)
|
||||||
|
|
||||||
|
return string[thought_start:thought_end], string[content_start:]
|
||||||
|
|
||||||
|
return None, string
|
||||||
|
|
||||||
|
|
||||||
|
def build_thinking_block(thinking_content, message_id, has_remaining_content):
|
||||||
def build_tool_call_block(header, body, message_id, index):
|
|
||||||
"""Build HTML for a tool call accordion block."""
|
|
||||||
block_id = f"tool-call-{message_id}-{index}"
|
|
||||||
|
|
||||||
if body == '...':
|
|
||||||
# Pending placeholder — no expandable body, just title with ellipsis
|
|
||||||
return f'''
|
|
||||||
<details class="thinking-block" data-block-id="{block_id}">
|
|
||||||
<summary class="thinking-header">
|
|
||||||
{tool_svg_small}
|
|
||||||
<span class="thinking-title">{html.escape(header)} ...</span>
|
|
||||||
</summary>
|
|
||||||
</details>
|
|
||||||
'''
|
|
||||||
|
|
||||||
# Build a plain <pre> directly to avoid highlight.js auto-detection
|
|
||||||
escaped_body = html.escape(body)
|
|
||||||
return f'''
|
|
||||||
<details class="thinking-block" data-block-id="{block_id}">
|
|
||||||
<summary class="thinking-header">
|
|
||||||
{tool_svg_small}
|
|
||||||
<span class="thinking-title">{html.escape(header)}</span>
|
|
||||||
</summary>
|
|
||||||
<div class="thinking-content pretty_scrollbar"><pre><code class="nohighlight">{escaped_body}</code></pre></div>
|
|
||||||
</details>
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
def build_thinking_block(thinking_content, message_id, has_remaining_content, thinking_index=0):
|
|
||||||
"""Build HTML for a thinking block."""
|
"""Build HTML for a thinking block."""
|
||||||
if thinking_content is None:
|
if thinking_content is None:
|
||||||
return None
|
return None
|
||||||
|
|
@ -159,7 +179,7 @@ def build_thinking_block(thinking_content, message_id, has_remaining_content, th
|
||||||
thinking_html = process_markdown_content(thinking_content)
|
thinking_html = process_markdown_content(thinking_content)
|
||||||
|
|
||||||
# Generate unique ID for the thinking block
|
# Generate unique ID for the thinking block
|
||||||
block_id = f"thinking-{message_id}-{thinking_index}"
|
block_id = f"thinking-{message_id}-0"
|
||||||
|
|
||||||
# Check if thinking is complete or still in progress
|
# Check if thinking is complete or still in progress
|
||||||
is_streaming = not has_remaining_content
|
is_streaming = not has_remaining_content
|
||||||
|
|
@ -192,29 +212,28 @@ def process_markdown_content(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Define unique placeholders for LaTeX characters that conflict with markdown
|
# Define unique placeholders for LaTeX asterisks and underscores
|
||||||
LATEX_ASTERISK_PLACEHOLDER = "LATEXASTERISKPLACEHOLDER"
|
LATEX_ASTERISK_PLACEHOLDER = "LATEXASTERISKPLACEHOLDER"
|
||||||
LATEX_UNDERSCORE_PLACEHOLDER = "LATEXUNDERSCOREPLACEHOLDER"
|
LATEX_UNDERSCORE_PLACEHOLDER = "LATEXUNDERSCOREPLACEHOLDER"
|
||||||
LATEX_PIPE_PLACEHOLDER = "LATEXPIPEPLACEHOLDER"
|
|
||||||
|
|
||||||
def protect_latex_content(content):
|
|
||||||
"""Protect markdown-sensitive characters inside LaTeX."""
|
|
||||||
content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
|
||||||
content = content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER)
|
|
||||||
content = content.replace('|', LATEX_PIPE_PLACEHOLDER)
|
|
||||||
return content
|
|
||||||
|
|
||||||
def protect_asterisks_underscores_in_latex(match):
|
def protect_asterisks_underscores_in_latex(match):
|
||||||
"""A replacer function for re.sub to protect markdown-sensitive characters in multiple LaTeX formats."""
|
"""A replacer function for re.sub to protect asterisks and underscores in multiple LaTeX formats."""
|
||||||
# Check which delimiter group was captured
|
# Check which delimiter group was captured
|
||||||
if match.group(1) is not None: # Content from $$...$$
|
if match.group(1) is not None: # Content from $$...$$
|
||||||
return protect_latex_content(match.group(1))
|
content = match.group(1)
|
||||||
|
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
||||||
|
modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER)
|
||||||
|
return f'{modified_content}'
|
||||||
elif match.group(2) is not None: # Content from \[...\]
|
elif match.group(2) is not None: # Content from \[...\]
|
||||||
return f'\\[{protect_latex_content(match.group(2))}\\]'
|
content = match.group(2)
|
||||||
|
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
||||||
|
modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER)
|
||||||
|
return f'\\[{modified_content}\\]'
|
||||||
elif match.group(3) is not None: # Content from \(...\)
|
elif match.group(3) is not None: # Content from \(...\)
|
||||||
return f'\\({protect_latex_content(match.group(3))}\\)'
|
content = match.group(3)
|
||||||
elif match.group(4) is not None: # Content from $...$
|
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
||||||
return f'${protect_latex_content(match.group(4).strip())}$'
|
modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER)
|
||||||
|
return f'\\({modified_content}\\)'
|
||||||
|
|
||||||
return match.group(0) # Fallback
|
return match.group(0) # Fallback
|
||||||
|
|
||||||
|
|
@ -248,7 +267,9 @@ def process_markdown_content(string):
|
||||||
string = re.sub(r"(.)```", r"\1\n```", string)
|
string = re.sub(r"(.)```", r"\1\n```", string)
|
||||||
|
|
||||||
# Protect asterisks and underscores within all LaTeX blocks before markdown conversion
|
# Protect asterisks and underscores within all LaTeX blocks before markdown conversion
|
||||||
string = _LATEX_PATTERN.sub(protect_asterisks_underscores_in_latex, string)
|
latex_pattern = re.compile(r'((?:^|[\r\n\s])\$\$[^`]*?\$\$)|\\\[(.*?)\\\]|\\\((.*?)\\\)',
|
||||||
|
re.DOTALL)
|
||||||
|
string = latex_pattern.sub(protect_asterisks_underscores_in_latex, string)
|
||||||
|
|
||||||
result = ''
|
result = ''
|
||||||
is_code = False
|
is_code = False
|
||||||
|
|
@ -312,7 +333,6 @@ def process_markdown_content(string):
|
||||||
# Restore the LaTeX asterisks and underscores after markdown conversion
|
# Restore the LaTeX asterisks and underscores after markdown conversion
|
||||||
html_output = html_output.replace(LATEX_ASTERISK_PLACEHOLDER, '*')
|
html_output = html_output.replace(LATEX_ASTERISK_PLACEHOLDER, '*')
|
||||||
html_output = html_output.replace(LATEX_UNDERSCORE_PLACEHOLDER, '_')
|
html_output = html_output.replace(LATEX_UNDERSCORE_PLACEHOLDER, '_')
|
||||||
html_output = html_output.replace(LATEX_PIPE_PLACEHOLDER, '|')
|
|
||||||
|
|
||||||
# Remove extra newlines before </code>
|
# Remove extra newlines before </code>
|
||||||
html_output = re.sub(r'\s*</code>', '</code>', html_output)
|
html_output = re.sub(r'\s*</code>', '</code>', html_output)
|
||||||
|
|
@ -324,9 +344,6 @@ def process_markdown_content(string):
|
||||||
# Unescape backslashes
|
# Unescape backslashes
|
||||||
html_output = html_output.replace('\\\\', '\\')
|
html_output = html_output.replace('\\\\', '\\')
|
||||||
|
|
||||||
# Wrap tables in a scrollable div
|
|
||||||
html_output = html_output.replace('<table>', '<div class="table-wrapper pretty_scrollbar"><table>').replace('</table>', '</table></div>')
|
|
||||||
|
|
||||||
return html_output
|
return html_output
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -343,67 +360,25 @@ def convert_to_markdown(string, message_id=None):
|
||||||
if message_id is None:
|
if message_id is None:
|
||||||
message_id = "unknown"
|
message_id = "unknown"
|
||||||
|
|
||||||
# Find tool call blocks by position, then process the text segments
|
# Extract different components from the string
|
||||||
# between them using extract_thinking_block (which supports all
|
|
||||||
# THINKING_FORMATS, including end-only variants like Qwen's).
|
|
||||||
tool_call_pattern = re.compile(r'<tool_call>(.*?)\n(.*?)\n</tool_call>', re.DOTALL)
|
|
||||||
tool_calls = list(tool_call_pattern.finditer(string))
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
# No tool calls — use original single-pass extraction
|
|
||||||
thinking_content, remaining_content = extract_thinking_block(string)
|
thinking_content, remaining_content = extract_thinking_block(string)
|
||||||
|
|
||||||
|
# Build individual HTML blocks
|
||||||
blocks = []
|
blocks = []
|
||||||
|
|
||||||
|
# Add thinking block if present
|
||||||
thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content))
|
thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content))
|
||||||
if thinking_html:
|
if thinking_html:
|
||||||
blocks.append(thinking_html)
|
blocks.append(thinking_html)
|
||||||
|
|
||||||
|
# Add main content block
|
||||||
main_html = build_main_content_block(remaining_content)
|
main_html = build_main_content_block(remaining_content)
|
||||||
if main_html:
|
if main_html:
|
||||||
blocks.append(main_html)
|
blocks.append(main_html)
|
||||||
|
|
||||||
|
# Assemble all blocks into final HTML
|
||||||
return ''.join(blocks)
|
return ''.join(blocks)
|
||||||
|
|
||||||
# Split string into text segments around tool_call blocks and
|
|
||||||
# run extract_thinking_block on each segment for full format support.
|
|
||||||
html_parts = []
|
|
||||||
last_end = 0
|
|
||||||
tool_idx = 0
|
|
||||||
think_idx = 0
|
|
||||||
|
|
||||||
def process_text_segment(text, is_last_segment):
|
|
||||||
"""Process a text segment between tool_call blocks for thinking content."""
|
|
||||||
nonlocal think_idx
|
|
||||||
if not text.strip():
|
|
||||||
return
|
|
||||||
|
|
||||||
while text.strip():
|
|
||||||
thinking_content, remaining = extract_thinking_block(text)
|
|
||||||
if thinking_content is None:
|
|
||||||
break
|
|
||||||
has_remaining = bool(remaining.strip()) or not is_last_segment
|
|
||||||
html_parts.append(build_thinking_block(thinking_content, message_id, has_remaining, think_idx))
|
|
||||||
think_idx += 1
|
|
||||||
text = remaining
|
|
||||||
|
|
||||||
if text.strip():
|
|
||||||
html_parts.append(process_markdown_content(text))
|
|
||||||
|
|
||||||
for tc in tool_calls:
|
|
||||||
# Process text before this tool_call
|
|
||||||
process_text_segment(string[last_end:tc.start()], is_last_segment=False)
|
|
||||||
|
|
||||||
# Add tool call accordion
|
|
||||||
header = tc.group(1).strip()
|
|
||||||
body = tc.group(2).strip()
|
|
||||||
html_parts.append(build_tool_call_block(header, body, message_id, tool_idx))
|
|
||||||
tool_idx += 1
|
|
||||||
last_end = tc.end()
|
|
||||||
|
|
||||||
# Process text after the last tool_call
|
|
||||||
process_text_segment(string[last_end:], is_last_segment=True)
|
|
||||||
|
|
||||||
return ''.join(html_parts)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):
|
def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):
|
||||||
'''
|
'''
|
||||||
|
|
@ -460,7 +435,6 @@ branch_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24
|
||||||
edit_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="tabler-icon tabler-icon-pencil"><path d="M4 20h4l10.5 -10.5a2.828 2.828 0 1 0 -4 -4l-10.5 10.5v4"></path><path d="M13.5 6.5l4 4"></path></svg>'''
|
edit_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="tabler-icon tabler-icon-pencil"><path d="M4 20h4l10.5 -10.5a2.828 2.828 0 1 0 -4 -4l-10.5 10.5v4"></path><path d="M13.5 6.5l4 4"></path></svg>'''
|
||||||
info_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
|
info_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
|
||||||
info_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
|
info_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
|
||||||
tool_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-tool"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M7 10h3v-3l-3.5 -3.5a6 6 0 0 1 8 8l6 6a2 2 0 0 1 -3 3l-6 -6a6 6 0 0 1 -8 -8l3.5 3.5" /></svg>'''
|
|
||||||
attachment_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.48-8.48l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48"></path></svg>'''
|
attachment_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.48-8.48l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48"></path></svg>'''
|
||||||
|
|
||||||
copy_button = f'<button class="footer-button footer-copy-button" title="Copy" onclick="copyToClipboard(this)">{copy_svg}</button>'
|
copy_button = f'<button class="footer-button footer-copy-button" title="Copy" onclick="copyToClipboard(this)">{copy_svg}</button>'
|
||||||
|
|
|
||||||
|
|
@ -10,49 +10,72 @@ def get_quantization_config(quant_method):
|
||||||
Get the appropriate quantization config based on the selected method.
|
Get the appropriate quantization config based on the selected method.
|
||||||
Applies quantization to both the transformer and the text_encoder.
|
Applies quantization to both the transformer and the text_encoder.
|
||||||
"""
|
"""
|
||||||
if quant_method == 'none' or not quant_method:
|
|
||||||
return None
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
# Import BitsAndBytesConfig from BOTH libraries to be safe
|
||||||
from diffusers import BitsAndBytesConfig as DiffusersBnBConfig
|
from diffusers import BitsAndBytesConfig as DiffusersBnBConfig
|
||||||
from diffusers import TorchAoConfig
|
from diffusers import TorchAoConfig
|
||||||
from diffusers.quantizers import PipelineQuantizationConfig
|
from diffusers.quantizers import PipelineQuantizationConfig
|
||||||
from transformers import BitsAndBytesConfig as TransformersBnBConfig
|
from transformers import BitsAndBytesConfig as TransformersBnBConfig
|
||||||
|
|
||||||
torchao_methods = {
|
if quant_method == 'none' or not quant_method:
|
||||||
'torchao-int8wo': 'int8wo',
|
return None
|
||||||
'torchao-fp4': 'fp4_e2m1',
|
|
||||||
'torchao-float8wo': 'float8wo',
|
|
||||||
}
|
|
||||||
|
|
||||||
if quant_method == 'bnb-8bit':
|
# Bitsandbytes 8-bit quantization
|
||||||
|
elif quant_method == 'bnb-8bit':
|
||||||
return PipelineQuantizationConfig(
|
return PipelineQuantizationConfig(
|
||||||
quant_mapping={
|
quant_mapping={
|
||||||
"transformer": DiffusersBnBConfig(load_in_8bit=True),
|
"transformer": DiffusersBnBConfig(
|
||||||
"text_encoder": TransformersBnBConfig(load_in_8bit=True)
|
load_in_8bit=True
|
||||||
|
),
|
||||||
|
"text_encoder": TransformersBnBConfig(
|
||||||
|
load_in_8bit=True
|
||||||
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Bitsandbytes 4-bit quantization
|
||||||
elif quant_method == 'bnb-4bit':
|
elif quant_method == 'bnb-4bit':
|
||||||
bnb_4bit_kwargs = dict(
|
return PipelineQuantizationConfig(
|
||||||
|
quant_mapping={
|
||||||
|
"transformer": DiffusersBnBConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||||
|
bnb_4bit_use_double_quant=True
|
||||||
|
),
|
||||||
|
"text_encoder": TransformersBnBConfig(
|
||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||||
bnb_4bit_use_double_quant=True
|
bnb_4bit_use_double_quant=True
|
||||||
)
|
)
|
||||||
return PipelineQuantizationConfig(
|
|
||||||
quant_mapping={
|
|
||||||
"transformer": DiffusersBnBConfig(**bnb_4bit_kwargs),
|
|
||||||
"text_encoder": TransformersBnBConfig(**bnb_4bit_kwargs)
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
elif quant_method in torchao_methods:
|
# torchao int8 weight-only
|
||||||
ao_type = torchao_methods[quant_method]
|
elif quant_method == 'torchao-int8wo':
|
||||||
return PipelineQuantizationConfig(
|
return PipelineQuantizationConfig(
|
||||||
quant_mapping={
|
quant_mapping={
|
||||||
"transformer": TorchAoConfig(ao_type),
|
"transformer": TorchAoConfig("int8wo"),
|
||||||
"text_encoder": TorchAoConfig(ao_type)
|
"text_encoder": TorchAoConfig("int8wo")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# torchao fp4 (e2m1)
|
||||||
|
elif quant_method == 'torchao-fp4':
|
||||||
|
return PipelineQuantizationConfig(
|
||||||
|
quant_mapping={
|
||||||
|
"transformer": TorchAoConfig("fp4_e2m1"),
|
||||||
|
"text_encoder": TorchAoConfig("fp4_e2m1")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# torchao float8 weight-only
|
||||||
|
elif quant_method == 'torchao-float8wo':
|
||||||
|
return PipelineQuantizationConfig(
|
||||||
|
quant_mapping={
|
||||||
|
"transformer": TorchAoConfig("float8wo"),
|
||||||
|
"text_encoder": TorchAoConfig("float8wo")
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -129,7 +152,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
|
||||||
|
|
||||||
modules = ["transformer", "unet"]
|
modules = ["transformer", "unet"]
|
||||||
|
|
||||||
# Set attention backend (diffusers defaults to native/SDPA)
|
# Set attention backend
|
||||||
if attn_backend == 'flash_attention_2':
|
if attn_backend == 'flash_attention_2':
|
||||||
for name in modules:
|
for name in modules:
|
||||||
mod = getattr(pipe, name, None)
|
mod = getattr(pipe, name, None)
|
||||||
|
|
|
||||||
|
|
@ -77,18 +77,7 @@ def process_message_content(content: Any) -> Tuple[str, List[Image.Image]]:
|
||||||
# Support external URLs
|
# Support external URLs
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
from urllib.parse import urljoin
|
response = requests.get(image_url, timeout=10)
|
||||||
from modules.web_search import _validate_url
|
|
||||||
_validate_url(image_url)
|
|
||||||
url = image_url
|
|
||||||
for _ in range(5):
|
|
||||||
response = requests.get(url, timeout=10, allow_redirects=False)
|
|
||||||
if response.is_redirect and 'Location' in response.headers:
|
|
||||||
url = urljoin(url, response.headers['Location'])
|
|
||||||
_validate_url(url)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
image_data = response.content
|
image_data = response.content
|
||||||
image = Image.open(io.BytesIO(image_data))
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pprint
|
import pprint
|
||||||
import shlex
|
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
@ -11,6 +10,7 @@ import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
|
import llama_cpp_binaries
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
@ -36,7 +36,6 @@ class LlamaServer:
|
||||||
self.process = None
|
self.process = None
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
self.vocabulary_size = None
|
self.vocabulary_size = None
|
||||||
self.n_ctx = None
|
|
||||||
self.bos_token = "<s>"
|
self.bos_token = "<s>"
|
||||||
self.last_prompt_token_count = 0
|
self.last_prompt_token_count = 0
|
||||||
|
|
||||||
|
|
@ -130,24 +129,13 @@ class LlamaServer:
|
||||||
# places it at the end of the chain regardless of position, so we
|
# places it at the end of the chain regardless of position, so we
|
||||||
# activate it based on the parameter value rather than sampler order.
|
# activate it based on the parameter value rather than sampler order.
|
||||||
if state.get("adaptive_target", 0) > 0:
|
if state.get("adaptive_target", 0) > 0:
|
||||||
filtered_samplers.append("adaptive_p")
|
filtered_samplers.append("adaptive-p")
|
||||||
|
|
||||||
payload["samplers"] = filtered_samplers
|
payload["samplers"] = filtered_samplers
|
||||||
|
|
||||||
logit_bias = []
|
|
||||||
if state['custom_token_bans']:
|
if state['custom_token_bans']:
|
||||||
logit_bias.extend([[int(token_id.strip()), False] for token_id in state['custom_token_bans'].split(',') if token_id.strip()])
|
to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
|
||||||
|
payload["logit_bias"] = to_ban
|
||||||
if state.get('logit_bias'):
|
|
||||||
for token_id_str, bias in state['logit_bias'].items():
|
|
||||||
logit_bias.append([int(token_id_str), bias])
|
|
||||||
|
|
||||||
if logit_bias:
|
|
||||||
payload["logit_bias"] = logit_bias
|
|
||||||
|
|
||||||
n_probs = state.get('logprobs', 0)
|
|
||||||
if n_probs and n_probs > 0:
|
|
||||||
payload["n_probs"] = n_probs
|
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
@ -227,7 +215,6 @@ class LlamaServer:
|
||||||
response.raise_for_status() # Raise an exception for HTTP errors
|
response.raise_for_status() # Raise an exception for HTTP errors
|
||||||
|
|
||||||
full_text = ""
|
full_text = ""
|
||||||
self.last_completion_probabilities = []
|
|
||||||
|
|
||||||
# Process the streaming response
|
# Process the streaming response
|
||||||
stop_event = state.get('stop_event')
|
stop_event = state.get('stop_event')
|
||||||
|
|
@ -253,10 +240,6 @@ class LlamaServer:
|
||||||
full_text += data['content']
|
full_text += data['content']
|
||||||
yield full_text
|
yield full_text
|
||||||
|
|
||||||
# Capture logprobs if present
|
|
||||||
if 'completion_probabilities' in data:
|
|
||||||
self.last_completion_probabilities.extend(data['completion_probabilities'])
|
|
||||||
|
|
||||||
# Check if generation is complete
|
# Check if generation is complete
|
||||||
if data.get('stop', False):
|
if data.get('stop', False):
|
||||||
break
|
break
|
||||||
|
|
@ -310,45 +293,8 @@ class LlamaServer:
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
|
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
|
||||||
|
|
||||||
def get_prompt_logprob_entries(self, token_ids, n_probs=5, prompt=""):
|
|
||||||
"""Get logprob entries for prompt tokens via a single n_predict=0 request.
|
|
||||||
|
|
||||||
Requires llama.cpp server with prompt_logprobs support.
|
|
||||||
Returns entries in the standard format for format_completion_logprobs().
|
|
||||||
"""
|
|
||||||
token_ids_list = token_ids.tolist() if hasattr(token_ids, 'tolist') else list(token_ids)
|
|
||||||
|
|
||||||
url = f"http://127.0.0.1:{self.port}/completion"
|
|
||||||
payload = {
|
|
||||||
"prompt": token_ids_list,
|
|
||||||
"n_predict": 0,
|
|
||||||
"n_probs": n_probs,
|
|
||||||
"prompt_logprobs": True,
|
|
||||||
"stream": False,
|
|
||||||
"cache_prompt": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = self.session.post(url, json=payload)
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
prompt_probs = result.get("prompt_probabilities", [])
|
|
||||||
if not prompt_probs:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Null first token (no conditioning context); use empty string for BOS
|
|
||||||
# or tokens that don't appear at the start of the prompt text.
|
|
||||||
first_token_str = self.decode([token_ids_list[0]])
|
|
||||||
if self.bos_token and first_token_str == self.bos_token:
|
|
||||||
first_token_str = ""
|
|
||||||
elif not prompt.startswith(first_token_str):
|
|
||||||
first_token_str = ""
|
|
||||||
|
|
||||||
entries = [{"token": first_token_str, "null_logprob": True}]
|
|
||||||
entries.extend(prompt_probs)
|
|
||||||
return entries
|
|
||||||
|
|
||||||
def _get_vocabulary_size(self):
|
def _get_vocabulary_size(self):
|
||||||
"""Get and store the model's vocabulary size."""
|
"""Get and store the model's maximum context length."""
|
||||||
url = f"http://127.0.0.1:{self.port}/v1/models"
|
url = f"http://127.0.0.1:{self.port}/v1/models"
|
||||||
response = self.session.get(url).json()
|
response = self.session.get(url).json()
|
||||||
|
|
||||||
|
|
@ -358,22 +304,16 @@ class LlamaServer:
|
||||||
self.vocabulary_size = model_info["meta"]["n_vocab"]
|
self.vocabulary_size = model_info["meta"]["n_vocab"]
|
||||||
|
|
||||||
def _get_bos_token(self):
|
def _get_bos_token(self):
|
||||||
"""Get and store the model's BOS token and context size."""
|
"""Get and store the model's BOS token."""
|
||||||
url = f"http://127.0.0.1:{self.port}/props"
|
url = f"http://127.0.0.1:{self.port}/props"
|
||||||
response = self.session.get(url).json()
|
response = self.session.get(url).json()
|
||||||
if "bos_token" in response:
|
if "bos_token" in response:
|
||||||
self.bos_token = response["bos_token"]
|
self.bos_token = response["bos_token"]
|
||||||
|
|
||||||
# Get actual n_ctx from the server (important when --fit auto-selects it)
|
|
||||||
n_ctx = response.get("default_generation_settings", {}).get("n_ctx")
|
|
||||||
if n_ctx:
|
|
||||||
self.n_ctx = n_ctx
|
|
||||||
|
|
||||||
def _is_port_available(self, port):
|
def _is_port_available(self, port):
|
||||||
"""Check if a port is available for use."""
|
"""Check if a port is available for use."""
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
try:
|
try:
|
||||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
||||||
s.bind(('', port))
|
s.bind(('', port))
|
||||||
return True
|
return True
|
||||||
except OSError:
|
except OSError:
|
||||||
|
|
@ -394,15 +334,6 @@ class LlamaServer:
|
||||||
"""Start the llama.cpp server and wait until it's ready."""
|
"""Start the llama.cpp server and wait until it's ready."""
|
||||||
# Determine the server path
|
# Determine the server path
|
||||||
if self.server_path is None:
|
if self.server_path is None:
|
||||||
if shared.args.ik:
|
|
||||||
try:
|
|
||||||
import ik_llama_cpp_binaries
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("--ik requires the ik_llama_cpp_binaries package. Install it with: pip install <ik_llama_cpp_binaries wheel URL>")
|
|
||||||
|
|
||||||
self.server_path = ik_llama_cpp_binaries.get_binary_path()
|
|
||||||
else:
|
|
||||||
import llama_cpp_binaries
|
|
||||||
self.server_path = llama_cpp_binaries.get_binary_path()
|
self.server_path = llama_cpp_binaries.get_binary_path()
|
||||||
|
|
||||||
# Build the command
|
# Build the command
|
||||||
|
|
@ -418,14 +349,11 @@ class LlamaServer:
|
||||||
|
|
||||||
if shared.args.ctx_size > 0:
|
if shared.args.ctx_size > 0:
|
||||||
cmd += ["--ctx-size", str(shared.args.ctx_size)]
|
cmd += ["--ctx-size", str(shared.args.ctx_size)]
|
||||||
elif shared.args.gpu_layers >= 0:
|
|
||||||
cmd += ["--ctx-size", "8192"]
|
|
||||||
|
|
||||||
if shared.args.gpu_layers >= 0:
|
if shared.args.gpu_layers >= 0:
|
||||||
cmd += ["--gpu-layers", str(shared.args.gpu_layers), "--fit", "off"]
|
cmd += ["--gpu-layers", str(shared.args.gpu_layers), "--fit", "off"]
|
||||||
else:
|
else:
|
||||||
cmd += ["--fit", "on"]
|
cmd += ["--fit", "on"]
|
||||||
cmd += ["--fit-ctx", "8192"]
|
|
||||||
if shared.args.fit_target:
|
if shared.args.fit_target:
|
||||||
cmd += ["--fit-target", shared.args.fit_target]
|
cmd += ["--fit-target", shared.args.fit_target]
|
||||||
|
|
||||||
|
|
@ -451,6 +379,10 @@ class LlamaServer:
|
||||||
if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
|
if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
|
||||||
cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
|
cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
|
||||||
cache_type = shared.args.cache_type
|
cache_type = shared.args.cache_type
|
||||||
|
if shared.args.compress_pos_emb != 1:
|
||||||
|
cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
|
||||||
|
if shared.args.rope_freq_base > 0:
|
||||||
|
cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
|
||||||
if shared.args.mmproj not in [None, 'None']:
|
if shared.args.mmproj not in [None, 'None']:
|
||||||
path = Path(shared.args.mmproj)
|
path = Path(shared.args.mmproj)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
|
|
@ -493,33 +425,22 @@ class LlamaServer:
|
||||||
elif extra_flags.startswith("'") and extra_flags.endswith("'"):
|
elif extra_flags.startswith("'") and extra_flags.endswith("'"):
|
||||||
extra_flags = extra_flags[1:-1].strip()
|
extra_flags = extra_flags[1:-1].strip()
|
||||||
|
|
||||||
if extra_flags.startswith('-'):
|
|
||||||
# New literal format: "--jinja --rpc 1222,1222"
|
|
||||||
cmd += shlex.split(extra_flags)
|
|
||||||
else:
|
|
||||||
# Legacy format: "flag1=value1,flag2,flag3=value3"
|
|
||||||
long_form_only = {'rpc', 'fit', 'pos', 'ppl'}
|
|
||||||
|
|
||||||
for flag_item in extra_flags.split(','):
|
for flag_item in extra_flags.split(','):
|
||||||
flag_item = flag_item.strip()
|
flag_item = flag_item.strip()
|
||||||
if '=' in flag_item:
|
if '=' in flag_item:
|
||||||
flag, value = flag_item.split('=', 1)
|
flag, value = flag_item.split('=', 1)
|
||||||
flag = flag.strip()
|
flag = flag.strip()
|
||||||
value = value.strip()
|
value = value.strip()
|
||||||
if len(flag) <= 3 and flag not in long_form_only:
|
if len(flag) <= 3:
|
||||||
cmd += [f"-{flag}", value]
|
cmd += [f"-{flag}", value]
|
||||||
else:
|
else:
|
||||||
cmd += [f"--{flag}", value]
|
cmd += [f"--{flag}", value]
|
||||||
else:
|
else:
|
||||||
if len(flag_item) <= 3 and flag_item not in long_form_only:
|
if len(flag_item) <= 3:
|
||||||
cmd.append(f"-{flag_item}")
|
cmd.append(f"-{flag_item}")
|
||||||
else:
|
else:
|
||||||
cmd.append(f"--{flag_item}")
|
cmd.append(f"--{flag_item}")
|
||||||
|
|
||||||
# Patch flags for ik_llama.cpp compatibility
|
|
||||||
if shared.args.ik:
|
|
||||||
cmd = _patch_cmd_for_ik(cmd)
|
|
||||||
|
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
if os.name == 'posix':
|
if os.name == 'posix':
|
||||||
current_path = env.get('LD_LIBRARY_PATH', '')
|
current_path = env.get('LD_LIBRARY_PATH', '')
|
||||||
|
|
@ -534,7 +455,7 @@ class LlamaServer:
|
||||||
print()
|
print()
|
||||||
|
|
||||||
gpu_layers_str = "auto" if shared.args.gpu_layers < 0 else str(shared.args.gpu_layers)
|
gpu_layers_str = "auto" if shared.args.gpu_layers < 0 else str(shared.args.gpu_layers)
|
||||||
ctx_size_str = "auto" if shared.args.ctx_size == 0 and shared.args.gpu_layers < 0 else str(shared.args.ctx_size or 8192)
|
ctx_size_str = "auto" if shared.args.ctx_size == 0 else str(shared.args.ctx_size)
|
||||||
logger.info(f"Using gpu_layers={gpu_layers_str} | ctx_size={ctx_size_str} | cache_type={cache_type}")
|
logger.info(f"Using gpu_layers={gpu_layers_str} | ctx_size={ctx_size_str} | cache_type={cache_type}")
|
||||||
# Start the server with pipes for output
|
# Start the server with pipes for output
|
||||||
self.process = subprocess.Popen(
|
self.process = subprocess.Popen(
|
||||||
|
|
@ -550,8 +471,9 @@ class LlamaServer:
|
||||||
health_url = f"http://127.0.0.1:{self.port}/health"
|
health_url = f"http://127.0.0.1:{self.port}/health"
|
||||||
while True:
|
while True:
|
||||||
# Check if process is still alive
|
# Check if process is still alive
|
||||||
|
if self.process.poll() is not None:
|
||||||
|
# Process has terminated
|
||||||
exit_code = self.process.poll()
|
exit_code = self.process.poll()
|
||||||
if exit_code is not None:
|
|
||||||
raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}")
|
raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -657,49 +579,3 @@ def filter_stderr_with_progress(process_stderr):
|
||||||
process_stderr.close()
|
process_stderr.close()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _patch_cmd_for_ik(cmd):
|
|
||||||
"""
|
|
||||||
Rewrite upstream llama.cpp flags to ik_llama.cpp equivalents:
|
|
||||||
--no-webui → --webui none
|
|
||||||
--fit off → (removed)
|
|
||||||
--fit on / --fit-ctx → --fit (bare flag)
|
|
||||||
--fit-target → --fit-margin
|
|
||||||
--cache-reuse → (removed, unsupported)
|
|
||||||
--swa-full → (removed, unsupported)
|
|
||||||
"""
|
|
||||||
# Add Hadamard KV cache rotation when using quantized cache types.
|
|
||||||
# This significantly improves quantized cache quality (especially q4_0)
|
|
||||||
# and is a no-op for MLA models like DeepSeek.
|
|
||||||
if shared.args.cache_type in ("q8_0", "q4_0"):
|
|
||||||
cmd += ["-khad", "-vhad"]
|
|
||||||
|
|
||||||
patched = []
|
|
||||||
i = 0
|
|
||||||
while i < len(cmd):
|
|
||||||
arg = cmd[i]
|
|
||||||
|
|
||||||
if arg == "--no-webui":
|
|
||||||
patched += ["--webui", "none"]
|
|
||||||
elif arg == "--fit" and i + 1 < len(cmd) and cmd[i + 1] in ("on", "off"):
|
|
||||||
val = cmd[i + 1]
|
|
||||||
i += 1
|
|
||||||
if val == "on":
|
|
||||||
patched.append("--fit")
|
|
||||||
# "off" → drop entirely
|
|
||||||
elif arg == "--fit-ctx":
|
|
||||||
patched.append("--fit")
|
|
||||||
i += 1 # skip the value
|
|
||||||
elif arg == "--fit-target":
|
|
||||||
patched.append("--fit-margin")
|
|
||||||
elif arg == "--cache-reuse":
|
|
||||||
i += 1 # skip the value
|
|
||||||
elif arg == "--swa-full":
|
|
||||||
pass # bare flag, just drop it
|
|
||||||
else:
|
|
||||||
patched.append(arg)
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return patched
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import functools
|
import functools
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
loaders_and_params = OrderedDict({
|
loaders_and_params = OrderedDict({
|
||||||
'llama.cpp': [
|
'llama.cpp': [
|
||||||
'gpu_layers',
|
'gpu_layers',
|
||||||
|
|
@ -15,12 +17,13 @@ loaders_and_params = OrderedDict({
|
||||||
'tensor_split',
|
'tensor_split',
|
||||||
'extra_flags',
|
'extra_flags',
|
||||||
'streaming_llm',
|
'streaming_llm',
|
||||||
|
'rope_freq_base',
|
||||||
|
'compress_pos_emb',
|
||||||
'row_split',
|
'row_split',
|
||||||
'no_kv_offload',
|
'no_kv_offload',
|
||||||
'no_mmap',
|
'no_mmap',
|
||||||
'mlock',
|
'mlock',
|
||||||
'numa',
|
'numa',
|
||||||
'ik',
|
|
||||||
'parallel',
|
'parallel',
|
||||||
'model_draft',
|
'model_draft',
|
||||||
'draft_max',
|
'draft_max',
|
||||||
|
|
@ -40,6 +43,8 @@ loaders_and_params = OrderedDict({
|
||||||
'Transformers': [
|
'Transformers': [
|
||||||
'gpu_split',
|
'gpu_split',
|
||||||
'cpu_memory',
|
'cpu_memory',
|
||||||
|
'alpha_value',
|
||||||
|
'compress_pos_emb',
|
||||||
'compute_dtype',
|
'compute_dtype',
|
||||||
'quant_type',
|
'quant_type',
|
||||||
'load_in_8bit',
|
'load_in_8bit',
|
||||||
|
|
@ -66,6 +71,7 @@ loaders_and_params = OrderedDict({
|
||||||
'gpu_split',
|
'gpu_split',
|
||||||
'model_draft',
|
'model_draft',
|
||||||
'draft_max',
|
'draft_max',
|
||||||
|
'ctx_size_draft',
|
||||||
'speculative_decoding_accordion',
|
'speculative_decoding_accordion',
|
||||||
'enable_tp',
|
'enable_tp',
|
||||||
'tp_backend',
|
'tp_backend',
|
||||||
|
|
@ -202,7 +208,6 @@ loaders_samplers = {
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
'add_bos_token',
|
'add_bos_token',
|
||||||
'enable_thinking',
|
'enable_thinking',
|
||||||
'reasoning_effort',
|
|
||||||
'seed',
|
'seed',
|
||||||
'skip_special_tokens',
|
'skip_special_tokens',
|
||||||
},
|
},
|
||||||
|
|
@ -239,7 +244,6 @@ loaders_samplers = {
|
||||||
'reasoning_effort',
|
'reasoning_effort',
|
||||||
'seed',
|
'seed',
|
||||||
'sampler_priority',
|
'sampler_priority',
|
||||||
'custom_token_bans',
|
|
||||||
'dry_sequence_breakers',
|
'dry_sequence_breakers',
|
||||||
'grammar_string',
|
'grammar_string',
|
||||||
'grammar_file_row',
|
'grammar_file_row',
|
||||||
|
|
@ -273,7 +277,6 @@ def list_all_samplers():
|
||||||
|
|
||||||
|
|
||||||
def blacklist_samplers(loader, dynamic_temperature):
|
def blacklist_samplers(loader, dynamic_temperature):
|
||||||
import gradio as gr
|
|
||||||
all_samplers = list_all_samplers()
|
all_samplers = list_all_samplers()
|
||||||
output = []
|
output = []
|
||||||
|
|
||||||
|
|
@ -291,77 +294,15 @@ def blacklist_samplers(loader, dynamic_temperature):
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def get_all_params():
|
def get_all_params():
|
||||||
from modules import shared
|
|
||||||
all_params = set()
|
all_params = set()
|
||||||
for k in loaders_and_params:
|
for k in loaders_and_params:
|
||||||
for el in loaders_and_params[k]:
|
for el in loaders_and_params[k]:
|
||||||
all_params.add(el)
|
all_params.add(el)
|
||||||
|
|
||||||
if shared.args.portable:
|
|
||||||
all_params.discard('ik')
|
|
||||||
|
|
||||||
return sorted(all_params)
|
return sorted(all_params)
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def list_model_elements():
|
|
||||||
elements = [
|
|
||||||
'filter_by_loader',
|
|
||||||
'loader',
|
|
||||||
'cpu_memory',
|
|
||||||
'gpu_layers',
|
|
||||||
'fit_target',
|
|
||||||
'cpu_moe',
|
|
||||||
'threads',
|
|
||||||
'threads_batch',
|
|
||||||
'batch_size',
|
|
||||||
'ubatch_size',
|
|
||||||
'ctx_size',
|
|
||||||
'cache_type',
|
|
||||||
'tensor_split',
|
|
||||||
'extra_flags',
|
|
||||||
'streaming_llm',
|
|
||||||
'gpu_split',
|
|
||||||
'compute_dtype',
|
|
||||||
'quant_type',
|
|
||||||
'load_in_8bit',
|
|
||||||
'load_in_4bit',
|
|
||||||
'attn_implementation',
|
|
||||||
'cpu',
|
|
||||||
'disk',
|
|
||||||
'row_split',
|
|
||||||
'no_kv_offload',
|
|
||||||
'no_mmap',
|
|
||||||
'mlock',
|
|
||||||
'numa',
|
|
||||||
'parallel',
|
|
||||||
'use_double_quant',
|
|
||||||
'bf16',
|
|
||||||
'enable_tp',
|
|
||||||
'tp_backend',
|
|
||||||
'cfg_cache',
|
|
||||||
'no_use_fast',
|
|
||||||
'model_draft',
|
|
||||||
'draft_max',
|
|
||||||
'gpu_layers_draft',
|
|
||||||
'device_draft',
|
|
||||||
'ctx_size_draft',
|
|
||||||
'spec_type',
|
|
||||||
'spec_ngram_size_n',
|
|
||||||
'spec_ngram_size_m',
|
|
||||||
'spec_ngram_min_hits',
|
|
||||||
'mmproj',
|
|
||||||
]
|
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
if not shared.args.portable:
|
|
||||||
elements.append('ik')
|
|
||||||
|
|
||||||
return elements
|
|
||||||
|
|
||||||
|
|
||||||
def make_loader_params_visible(loader):
|
def make_loader_params_visible(loader):
|
||||||
import gradio as gr
|
|
||||||
params = []
|
params = []
|
||||||
all_params = get_all_params()
|
all_params = get_all_params()
|
||||||
if loader in loaders_and_params:
|
if loader in loaders_and_params:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modules import models, shared
|
from modules import models, shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
from modules.models import load_model
|
||||||
from modules.text_generation import generate_reply
|
from modules.text_generation import generate_reply
|
||||||
from modules.utils import check_model_loaded
|
from modules.utils import check_model_loaded
|
||||||
|
|
||||||
|
|
@ -11,7 +13,8 @@ global_scores = None
|
||||||
|
|
||||||
|
|
||||||
def get_next_logits(*args, **kwargs):
|
def get_next_logits(*args, **kwargs):
|
||||||
models.load_model_if_idle_unloaded()
|
if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
needs_lock = not args[2] # use_samplers
|
needs_lock = not args[2] # use_samplers
|
||||||
if needs_lock:
|
if needs_lock:
|
||||||
|
|
@ -20,7 +23,7 @@ def get_next_logits(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
result = _get_next_logits(*args, **kwargs)
|
result = _get_next_logits(*args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to get next logits")
|
traceback.print_exc()
|
||||||
result = None
|
result = None
|
||||||
|
|
||||||
if needs_lock:
|
if needs_lock:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
@ -8,15 +7,6 @@ from modules.models_settings import get_model_metadata
|
||||||
from modules.utils import resolve_model_path
|
from modules.utils import resolve_model_path
|
||||||
|
|
||||||
last_generation_time = time.time()
|
last_generation_time = time.time()
|
||||||
active_generation_count = 0
|
|
||||||
_generation_count_lock = threading.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_if_idle_unloaded():
|
|
||||||
global last_generation_time
|
|
||||||
if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
|
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
|
||||||
last_generation_time = time.time()
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_name, loader=None):
|
def load_model(model_name, loader=None):
|
||||||
|
|
@ -48,9 +38,6 @@ def load_model(model_name, loader=None):
|
||||||
sampler_hijack.hijack_samplers()
|
sampler_hijack.hijack_samplers()
|
||||||
|
|
||||||
shared.args.loader = loader
|
shared.args.loader = loader
|
||||||
if loader != 'llama.cpp' and shared.args.ctx_size == 0:
|
|
||||||
shared.args.ctx_size = 8192
|
|
||||||
|
|
||||||
output = load_func_map[loader](model_name)
|
output = load_func_map[loader](model_name)
|
||||||
if type(output) is tuple:
|
if type(output) is tuple:
|
||||||
model, tokenizer = output
|
model, tokenizer = output
|
||||||
|
|
@ -67,8 +54,6 @@ def load_model(model_name, loader=None):
|
||||||
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
|
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
|
||||||
if shared.args.ctx_size > 0:
|
if shared.args.ctx_size > 0:
|
||||||
shared.settings['truncation_length'] = shared.args.ctx_size
|
shared.settings['truncation_length'] = shared.args.ctx_size
|
||||||
elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx:
|
|
||||||
shared.settings['truncation_length'] = model.n_ctx
|
|
||||||
|
|
||||||
shared.is_multimodal = False
|
shared.is_multimodal = False
|
||||||
if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):
|
if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):
|
||||||
|
|
@ -76,7 +61,8 @@ def load_model(model_name, loader=None):
|
||||||
|
|
||||||
logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
|
logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
|
||||||
logger.info(f"LOADER: \"{loader}\"")
|
logger.info(f"LOADER: \"{loader}\"")
|
||||||
logger.info(f"CONTEXT LENGTH: {shared.settings['truncation_length']}")
|
logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
|
||||||
|
logger.info(f"INSTRUCTION TEMPLATE: \"{metadata['instruction_template']}\"")
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -168,10 +154,7 @@ def unload_model_if_idle():
|
||||||
while True:
|
while True:
|
||||||
shared.generation_lock.acquire()
|
shared.generation_lock.acquire()
|
||||||
try:
|
try:
|
||||||
with _generation_count_lock:
|
if time.time() - last_generation_time > shared.args.idle_timeout * 60:
|
||||||
is_active = active_generation_count > 0
|
|
||||||
|
|
||||||
if not is_active and time.time() - last_generation_time > shared.args.idle_timeout * 60:
|
|
||||||
if shared.model is not None:
|
if shared.model is not None:
|
||||||
logger.info("Unloading the model for inactivity.")
|
logger.info("Unloading the model for inactivity.")
|
||||||
unload_model(keep_model_name=True)
|
unload_model(keep_model_name=True)
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,10 @@ import re
|
||||||
from math import floor
|
from math import floor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from modules import loaders, metadata_gguf, shared
|
from modules import chat, loaders, metadata_gguf, shared, ui
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.utils import resolve_model_path
|
from modules.utils import resolve_model_path
|
||||||
|
|
||||||
|
|
@ -15,6 +16,9 @@ def get_fallback_settings():
|
||||||
return {
|
return {
|
||||||
'bf16': False,
|
'bf16': False,
|
||||||
'ctx_size': 8192,
|
'ctx_size': 8192,
|
||||||
|
'rope_freq_base': 0,
|
||||||
|
'compress_pos_emb': 1,
|
||||||
|
'alpha_value': 1,
|
||||||
'truncation_length': shared.settings['truncation_length'],
|
'truncation_length': shared.settings['truncation_length'],
|
||||||
'truncation_length_info': shared.settings['truncation_length'],
|
'truncation_length_info': shared.settings['truncation_length'],
|
||||||
'skip_special_tokens': shared.settings['skip_special_tokens'],
|
'skip_special_tokens': shared.settings['skip_special_tokens'],
|
||||||
|
|
@ -23,14 +27,18 @@ def get_fallback_settings():
|
||||||
|
|
||||||
def get_model_metadata(model):
|
def get_model_metadata(model):
|
||||||
model_path = resolve_model_path(model)
|
model_path = resolve_model_path(model)
|
||||||
|
model_settings = {}
|
||||||
|
|
||||||
# Fallback settings
|
# Get settings from user_data/models/config.yaml and user_data/models/config-user.yaml
|
||||||
model_settings = get_fallback_settings()
|
settings = shared.model_config
|
||||||
|
for pat in settings:
|
||||||
|
if re.match(pat.lower(), Path(model).name.lower()):
|
||||||
|
for k in settings[pat]:
|
||||||
|
model_settings[k] = settings[pat][k]
|
||||||
|
|
||||||
path = model_path / 'config.json'
|
path = model_path / 'config.json'
|
||||||
if path.exists():
|
if path.exists():
|
||||||
with open(path, 'r', encoding='utf-8') as f:
|
hf_metadata = json.loads(open(path, 'r', encoding='utf-8').read())
|
||||||
hf_metadata = json.loads(f.read())
|
|
||||||
else:
|
else:
|
||||||
hf_metadata = None
|
hf_metadata = None
|
||||||
|
|
||||||
|
|
@ -60,8 +68,14 @@ def get_model_metadata(model):
|
||||||
|
|
||||||
for k in metadata:
|
for k in metadata:
|
||||||
if k.endswith('.context_length'):
|
if k.endswith('.context_length'):
|
||||||
model_settings['ctx_size'] = 0
|
model_settings['ctx_size'] = min(metadata[k], 8192)
|
||||||
model_settings['truncation_length_info'] = metadata[k]
|
model_settings['truncation_length_info'] = metadata[k]
|
||||||
|
elif k.endswith('rope.freq_base'):
|
||||||
|
model_settings['rope_freq_base'] = metadata[k]
|
||||||
|
elif k.endswith('rope.scale_linear'):
|
||||||
|
model_settings['compress_pos_emb'] = metadata[k]
|
||||||
|
elif k.endswith('rope.scaling.factor'):
|
||||||
|
model_settings['compress_pos_emb'] = metadata[k]
|
||||||
elif k.endswith('.block_count'):
|
elif k.endswith('.block_count'):
|
||||||
model_settings['gpu_layers'] = -1
|
model_settings['gpu_layers'] = -1
|
||||||
model_settings['max_gpu_layers'] = metadata[k] + 1
|
model_settings['max_gpu_layers'] = metadata[k] + 1
|
||||||
|
|
@ -89,7 +103,7 @@ def get_model_metadata(model):
|
||||||
else:
|
else:
|
||||||
# Transformers metadata
|
# Transformers metadata
|
||||||
if hf_metadata is not None:
|
if hf_metadata is not None:
|
||||||
metadata = hf_metadata
|
metadata = json.loads(open(path, 'r', encoding='utf-8').read())
|
||||||
if 'pretrained_config' in metadata:
|
if 'pretrained_config' in metadata:
|
||||||
metadata = metadata['pretrained_config']
|
metadata = metadata['pretrained_config']
|
||||||
|
|
||||||
|
|
@ -106,6 +120,15 @@ def get_model_metadata(model):
|
||||||
model_settings['ctx_size'] = min(value, 8192)
|
model_settings['ctx_size'] = min(value, 8192)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if 'rope_theta' in metadata:
|
||||||
|
model_settings['rope_freq_base'] = metadata['rope_theta']
|
||||||
|
elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']:
|
||||||
|
model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta']
|
||||||
|
|
||||||
|
if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
|
||||||
|
if metadata['rope_scaling']['type'] == 'linear':
|
||||||
|
model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
|
||||||
|
|
||||||
if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
|
if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
|
||||||
model_settings['bf16'] = True
|
model_settings['bf16'] = True
|
||||||
|
|
||||||
|
|
@ -130,8 +153,7 @@ def get_model_metadata(model):
|
||||||
|
|
||||||
# 3. Fall back to tokenizer_config.json metadata
|
# 3. Fall back to tokenizer_config.json metadata
|
||||||
if path.exists():
|
if path.exists():
|
||||||
with open(path, 'r', encoding='utf-8') as f:
|
metadata = json.loads(open(path, 'r', encoding='utf-8').read())
|
||||||
metadata = json.loads(f.read())
|
|
||||||
|
|
||||||
# Only read from metadata if we haven't already loaded from .jinja or .json
|
# Only read from metadata if we haven't already loaded from .jinja or .json
|
||||||
if template is None and 'chat_template' in metadata:
|
if template is None and 'chat_template' in metadata:
|
||||||
|
|
@ -160,6 +182,10 @@ def get_model_metadata(model):
|
||||||
if 'instruction_template' not in model_settings:
|
if 'instruction_template' not in model_settings:
|
||||||
model_settings['instruction_template'] = 'Alpaca'
|
model_settings['instruction_template'] = 'Alpaca'
|
||||||
|
|
||||||
|
# Ignore rope_freq_base if set to the default value
|
||||||
|
if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000:
|
||||||
|
model_settings.pop('rope_freq_base')
|
||||||
|
|
||||||
# Apply user settings from user_data/models/config-user.yaml
|
# Apply user settings from user_data/models/config-user.yaml
|
||||||
settings = shared.user_config
|
settings = shared.user_config
|
||||||
for pat in settings:
|
for pat in settings:
|
||||||
|
|
@ -173,7 +199,7 @@ def get_model_metadata(model):
|
||||||
|
|
||||||
# Load instruction template if defined by name rather than by value
|
# Load instruction template if defined by name rather than by value
|
||||||
if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
|
if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
|
||||||
model_settings['instruction_template_str'] = load_instruction_template(model_settings['instruction_template'])
|
model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
|
||||||
|
|
||||||
return model_settings
|
return model_settings
|
||||||
|
|
||||||
|
|
@ -202,7 +228,7 @@ def update_model_parameters(state, initial=False):
|
||||||
'''
|
'''
|
||||||
UI: update the command-line arguments based on the interface values
|
UI: update the command-line arguments based on the interface values
|
||||||
'''
|
'''
|
||||||
elements = loaders.list_model_elements() # the names of the parameters
|
elements = ui.list_model_elements() # the names of the parameters
|
||||||
|
|
||||||
for i, element in enumerate(elements):
|
for i, element in enumerate(elements):
|
||||||
if element not in state:
|
if element not in state:
|
||||||
|
|
@ -222,7 +248,6 @@ def apply_model_settings_to_state(model, state):
|
||||||
'''
|
'''
|
||||||
UI: update the state variable with the model settings
|
UI: update the state variable with the model settings
|
||||||
'''
|
'''
|
||||||
import gradio as gr
|
|
||||||
model_settings = get_model_metadata(model)
|
model_settings = get_model_metadata(model)
|
||||||
if 'loader' in model_settings:
|
if 'loader' in model_settings:
|
||||||
loader = model_settings.pop('loader')
|
loader = model_settings.pop('loader')
|
||||||
|
|
@ -265,7 +290,7 @@ def save_model_settings(model, state):
|
||||||
if model_regex not in user_config:
|
if model_regex not in user_config:
|
||||||
user_config[model_regex] = {}
|
user_config[model_regex] = {}
|
||||||
|
|
||||||
for k in loaders.list_model_elements():
|
for k in ui.list_model_elements():
|
||||||
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
|
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
|
||||||
user_config[model_regex][k] = state[k]
|
user_config[model_regex][k] = state[k]
|
||||||
|
|
||||||
|
|
@ -394,108 +419,3 @@ def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type):
|
||||||
|
|
||||||
vram_usage = estimate_vram(model, gpu_layers, ctx_size, cache_type)
|
vram_usage = estimate_vram(model, gpu_layers, ctx_size, cache_type)
|
||||||
return f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">{vram_usage:.0f} MiB</span></div>"
|
return f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">{vram_usage:.0f} MiB</span></div>"
|
||||||
|
|
||||||
|
|
||||||
def load_instruction_template(template):
|
|
||||||
if template == 'None':
|
|
||||||
return ''
|
|
||||||
|
|
||||||
for name in (template, 'Alpaca'):
|
|
||||||
path = shared.user_data_dir / 'instruction-templates' / f'{name}.yaml'
|
|
||||||
try:
|
|
||||||
with open(path, 'r', encoding='utf-8') as f:
|
|
||||||
file_contents = f.read()
|
|
||||||
except FileNotFoundError:
|
|
||||||
if name == template:
|
|
||||||
logger.warning(f"Instruction template '{template}' not found, falling back to Alpaca")
|
|
||||||
continue
|
|
||||||
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
return ''
|
|
||||||
data = yaml.safe_load(file_contents)
|
|
||||||
if 'instruction_template' in data:
|
|
||||||
return data['instruction_template']
|
|
||||||
else:
|
|
||||||
return _jinja_template_from_old_format(data)
|
|
||||||
|
|
||||||
|
|
||||||
def _jinja_template_from_old_format(params, verbose=False):
|
|
||||||
MASTER_TEMPLATE = """
|
|
||||||
{%- set ns = namespace(found=false) -%}
|
|
||||||
{%- for message in messages -%}
|
|
||||||
{%- if message['role'] == 'system' -%}
|
|
||||||
{%- set ns.found = true -%}
|
|
||||||
{%- endif -%}
|
|
||||||
{%- endfor -%}
|
|
||||||
{%- if not ns.found -%}
|
|
||||||
{{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
|
|
||||||
{%- endif %}
|
|
||||||
{%- for message in messages %}
|
|
||||||
{%- if message['role'] == 'system' -%}
|
|
||||||
{{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
|
|
||||||
{%- else -%}
|
|
||||||
{%- if message['role'] == 'user' -%}
|
|
||||||
{{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
|
|
||||||
{%- else -%}
|
|
||||||
{{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
|
|
||||||
{%- endif -%}
|
|
||||||
{%- endif -%}
|
|
||||||
{%- endfor -%}
|
|
||||||
{%- if add_generation_prompt -%}
|
|
||||||
{{-'<|PRE-ASSISTANT-GENERATE|>'-}}
|
|
||||||
{%- endif -%}
|
|
||||||
"""
|
|
||||||
|
|
||||||
if 'context' in params and '<|system-message|>' in params['context']:
|
|
||||||
pre_system = params['context'].split('<|system-message|>')[0]
|
|
||||||
post_system = params['context'].split('<|system-message|>')[1]
|
|
||||||
else:
|
|
||||||
pre_system = ''
|
|
||||||
post_system = ''
|
|
||||||
|
|
||||||
pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
|
|
||||||
post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]
|
|
||||||
|
|
||||||
pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
|
|
||||||
pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
|
|
||||||
post_assistant = params['turn_template'].split('<|bot-message|>')[1]
|
|
||||||
|
|
||||||
def preprocess(string):
|
|
||||||
return string.replace('\n', '\\n').replace('\'', '\\\'')
|
|
||||||
|
|
||||||
pre_system = preprocess(pre_system)
|
|
||||||
post_system = preprocess(post_system)
|
|
||||||
pre_user = preprocess(pre_user)
|
|
||||||
post_user = preprocess(post_user)
|
|
||||||
pre_assistant = preprocess(pre_assistant)
|
|
||||||
post_assistant = preprocess(post_assistant)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(
|
|
||||||
'\n',
|
|
||||||
repr(pre_system) + '\n',
|
|
||||||
repr(post_system) + '\n',
|
|
||||||
repr(pre_user) + '\n',
|
|
||||||
repr(post_user) + '\n',
|
|
||||||
repr(pre_assistant) + '\n',
|
|
||||||
repr(post_assistant) + '\n',
|
|
||||||
)
|
|
||||||
|
|
||||||
result = MASTER_TEMPLATE
|
|
||||||
if 'system_message' in params:
|
|
||||||
result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message']))
|
|
||||||
else:
|
|
||||||
result = result.replace('<|SYSTEM-MESSAGE|>', '')
|
|
||||||
|
|
||||||
result = result.replace('<|PRE-SYSTEM|>', pre_system)
|
|
||||||
result = result.replace('<|POST-SYSTEM|>', post_system)
|
|
||||||
result = result.replace('<|PRE-USER|>', pre_user)
|
|
||||||
result = result.replace('<|POST-USER|>', post_user)
|
|
||||||
result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
|
|
||||||
result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
|
|
||||||
result = result.replace('<|POST-ASSISTANT|>', post_assistant)
|
|
||||||
|
|
||||||
result = result.strip()
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
|
||||||
|
|
@ -16,10 +16,9 @@ default_preset_values = {
|
||||||
'dynatemp_exponent': 1,
|
'dynatemp_exponent': 1,
|
||||||
'smoothing_factor': 0,
|
'smoothing_factor': 0,
|
||||||
'smoothing_curve': 1,
|
'smoothing_curve': 1,
|
||||||
|
'min_p': 0,
|
||||||
'top_p': 1,
|
'top_p': 1,
|
||||||
'top_k': 0,
|
'top_k': 0,
|
||||||
'min_p': 0,
|
|
||||||
'top_n_sigma': 0,
|
|
||||||
'typical_p': 1,
|
'typical_p': 1,
|
||||||
'xtc_threshold': 0.1,
|
'xtc_threshold': 0.1,
|
||||||
'xtc_probability': 0,
|
'xtc_probability': 0,
|
||||||
|
|
@ -27,6 +26,7 @@ default_preset_values = {
|
||||||
'eta_cutoff': 0,
|
'eta_cutoff': 0,
|
||||||
'tfs': 1,
|
'tfs': 1,
|
||||||
'top_a': 0,
|
'top_a': 0,
|
||||||
|
'top_n_sigma': 0,
|
||||||
'adaptive_target': 0,
|
'adaptive_target': 0,
|
||||||
'adaptive_decay': 0.9,
|
'adaptive_decay': 0.9,
|
||||||
'dry_multiplier': 0,
|
'dry_multiplier': 0,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from modules import shared, utils
|
from modules import shared, utils
|
||||||
from modules.utils import sanitize_filename
|
|
||||||
from modules.text_generation import get_encoded_length
|
from modules.text_generation import get_encoded_length
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -19,7 +18,6 @@ def load_prompt(fname):
|
||||||
|
|
||||||
return initial_content
|
return initial_content
|
||||||
|
|
||||||
fname = sanitize_filename(fname)
|
|
||||||
file_path = shared.user_data_dir / 'logs' / 'notebook' / f'{fname}.txt'
|
file_path = shared.user_data_dir / 'logs' / 'notebook' / f'{fname}.txt'
|
||||||
if file_path.exists():
|
if file_path.exists():
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
|
|
||||||
|
|
@ -1,101 +0,0 @@
|
||||||
import html as html_module
|
|
||||||
|
|
||||||
# Thinking block format definitions: (start_tag, end_tag, content_start_tag)
|
|
||||||
# Use None for start_tag to match from beginning (end-only formats should be listed last)
|
|
||||||
THINKING_FORMATS = [
|
|
||||||
('<think>', '</think>', None),
|
|
||||||
('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'),
|
|
||||||
('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'),
|
|
||||||
('<seed:think>', '</seed:think>', None),
|
|
||||||
('<|channel>thought', '<channel|>', None), # Gemma 4
|
|
||||||
('<|think|>', '<|end|>', '<|content|>'), # Solar Open
|
|
||||||
# ('Thinking Process:', '</think>', None), # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming
|
|
||||||
(None, '</think>', None), # End-only variant (e.g., Qwen3-next)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def extract_reasoning(text, html_escaped=False):
|
|
||||||
"""Extract reasoning/thinking blocks from the beginning of a string.
|
|
||||||
|
|
||||||
When html_escaped=True, tags are HTML-escaped before searching
|
|
||||||
(for use on already-escaped UI strings).
|
|
||||||
|
|
||||||
Returns (reasoning_content, final_content) where reasoning_content is
|
|
||||||
None if no thinking block is found.
|
|
||||||
"""
|
|
||||||
if not text:
|
|
||||||
return None, text
|
|
||||||
|
|
||||||
esc = html_module.escape if html_escaped else lambda s: s
|
|
||||||
|
|
||||||
for start_tag, end_tag, content_tag in THINKING_FORMATS:
|
|
||||||
end_esc = esc(end_tag)
|
|
||||||
content_esc = esc(content_tag) if content_tag else None
|
|
||||||
|
|
||||||
if start_tag is None:
|
|
||||||
# End-only format: require end tag, start from beginning
|
|
||||||
end_pos = text.find(end_esc)
|
|
||||||
if end_pos == -1:
|
|
||||||
continue
|
|
||||||
thought_start = 0
|
|
||||||
else:
|
|
||||||
# Normal format: require start tag
|
|
||||||
start_esc = esc(start_tag)
|
|
||||||
start_pos = text.find(start_esc)
|
|
||||||
if start_pos == -1:
|
|
||||||
# During streaming, the start tag may be arriving partially.
|
|
||||||
# If the text is a prefix of a start tag, return empty content
|
|
||||||
# to prevent the partial tag from leaking.
|
|
||||||
stripped = text.strip()
|
|
||||||
if stripped and start_esc.startswith(stripped):
|
|
||||||
return '', ''
|
|
||||||
continue
|
|
||||||
thought_start = start_pos + len(start_esc)
|
|
||||||
end_pos = text.find(end_esc, thought_start)
|
|
||||||
|
|
||||||
if end_pos == -1:
|
|
||||||
# End tag missing - check if content tag can serve as fallback
|
|
||||||
if content_esc:
|
|
||||||
content_pos = text.find(content_esc, thought_start)
|
|
||||||
if content_pos != -1:
|
|
||||||
thought_end = content_pos
|
|
||||||
content_start = content_pos + len(content_esc)
|
|
||||||
else:
|
|
||||||
thought_end = len(text)
|
|
||||||
content_start = len(text)
|
|
||||||
else:
|
|
||||||
thought_end = len(text)
|
|
||||||
content_start = len(text)
|
|
||||||
else:
|
|
||||||
thought_end = end_pos
|
|
||||||
if content_esc:
|
|
||||||
content_pos = text.find(content_esc, end_pos)
|
|
||||||
if content_pos != -1:
|
|
||||||
content_start = content_pos + len(content_esc)
|
|
||||||
else:
|
|
||||||
# Content tag not present yet. In GPT-OSS the region
|
|
||||||
# between <|end|> and the content tag contains internal
|
|
||||||
# markup (<|start|>assistant…) that must not be shown.
|
|
||||||
# Suppress it to prevent tag leaks during streaming.
|
|
||||||
remainder = text[end_pos + len(end_esc):].lstrip()
|
|
||||||
framing_token = esc('<|start|>')
|
|
||||||
if not remainder or remainder.startswith(framing_token) or framing_token.startswith(remainder):
|
|
||||||
content_start = len(text)
|
|
||||||
else:
|
|
||||||
content_start = end_pos + len(end_esc)
|
|
||||||
else:
|
|
||||||
content_start = end_pos + len(end_esc)
|
|
||||||
|
|
||||||
return text[thought_start:thought_end], text[content_start:].lstrip()
|
|
||||||
|
|
||||||
# Handle standalone GPT-OSS final channel marker without a preceding
|
|
||||||
# analysis/commentary block (the model skipped thinking entirely).
|
|
||||||
for marker in ['<|start|>assistant<|channel|>final<|message|>', '<|channel|>final<|message|>']:
|
|
||||||
marker_esc = esc(marker)
|
|
||||||
pos = text.find(marker_esc)
|
|
||||||
if pos != -1:
|
|
||||||
before = text[:pos].strip()
|
|
||||||
after = text[pos + len(marker_esc):]
|
|
||||||
return (before if before else None), after
|
|
||||||
|
|
||||||
return None, text
|
|
||||||
|
|
@ -47,7 +47,7 @@ parser = argparse.ArgumentParser(description="Text Generation Web UI", conflict_
|
||||||
# Basic settings
|
# Basic settings
|
||||||
group = parser.add_argument_group('Basic settings')
|
group = parser.add_argument_group('Basic settings')
|
||||||
group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user data directory. Default: auto-detected.')
|
group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user data directory. Default: auto-detected.')
|
||||||
group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.')
|
group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.')
|
||||||
group.add_argument('--model', type=str, help='Name of the model to load by default.')
|
group.add_argument('--model', type=str, help='Name of the model to load by default.')
|
||||||
group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
|
group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
|
||||||
group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.')
|
group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.')
|
||||||
|
|
@ -76,7 +76,7 @@ group.add_argument('--loader', type=str, help='Choose the model loader manually,
|
||||||
|
|
||||||
# Cache
|
# Cache
|
||||||
group = parser.add_argument_group('Context and cache')
|
group = parser.add_argument_group('Context and cache')
|
||||||
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=0, metavar='N', help='Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.')
|
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, metavar='N', help='Context size in tokens. llama.cpp: 0 = auto if gpu-layers is also -1.')
|
||||||
group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
|
group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
|
|
@ -101,16 +101,15 @@ group.add_argument('--tensor-split', type=str, default=None, help='Split the mod
|
||||||
group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
||||||
group.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.')
|
group.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.')
|
||||||
group.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')
|
group.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')
|
||||||
group.add_argument('--no-kv-offload', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces performance.')
|
group.add_argument('--no-kv-offload', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
|
||||||
group.add_argument('--batch-size', type=int, default=1024, help='Maximum number of prompt tokens to batch together when calling llama-server. This is the application level batch size.')
|
group.add_argument('--batch-size', type=int, default=1024, help='Maximum number of prompt tokens to batch together when calling llama-server. This is the application level batch size.')
|
||||||
group.add_argument('--ubatch-size', type=int, default=1024, help='Maximum number of prompt tokens to batch together when calling llama-server. This is the max physical batch size for computation (device level).')
|
group.add_argument('--ubatch-size', type=int, default=1024, help='Maximum number of prompt tokens to batch together when calling llama-server. This is the max physical batch size for computation (device level).')
|
||||||
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
|
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
|
||||||
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
|
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
|
||||||
group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
|
group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
|
||||||
group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
|
group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
|
||||||
group.add_argument('--fit-target', type=str, default='512', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.')
|
group.add_argument('--fit-target', type=str, default='1024', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices. Default: 1024.')
|
||||||
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Example: "--jinja --rpc 192.168.1.100:50052"')
|
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
|
||||||
group.add_argument('--ik', action='store_true', help='Use ik_llama.cpp instead of upstream llama.cpp. Requires the ik_llama_cpp_binaries package to be installed.')
|
|
||||||
|
|
||||||
# Transformers/Accelerate
|
# Transformers/Accelerate
|
||||||
group = parser.add_argument_group('Transformers/Accelerate')
|
group = parser.add_argument_group('Transformers/Accelerate')
|
||||||
|
|
@ -140,6 +139,12 @@ group.add_argument('--enable-tp', '--enable_tp', action='store_true', help='Enab
|
||||||
group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.')
|
group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.')
|
||||||
group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
|
group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
|
||||||
|
|
||||||
|
# RoPE
|
||||||
|
group = parser.add_argument_group('RoPE')
|
||||||
|
group.add_argument('--alpha_value', type=float, default=1, help='Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.')
|
||||||
|
group.add_argument('--rope_freq_base', type=int, default=0, help='If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).')
|
||||||
|
group.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.")
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
group = parser.add_argument_group('Gradio')
|
group = parser.add_argument_group('Gradio')
|
||||||
group.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
|
group.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
|
||||||
|
|
@ -157,8 +162,8 @@ group.add_argument('--portable', action='store_true', help='Hide features not av
|
||||||
|
|
||||||
# API
|
# API
|
||||||
group = parser.add_argument_group('API')
|
group = parser.add_argument_group('API')
|
||||||
group.add_argument('--api', action='store_true', help='Enable the API server.')
|
group.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||||
group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')
|
group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
||||||
group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
|
group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
|
||||||
group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
|
group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
|
||||||
group.add_argument('--api-key', type=str, default='', help='API authentication key.')
|
group.add_argument('--api-key', type=str, default='', help='API authentication key.')
|
||||||
|
|
@ -176,10 +181,9 @@ group.add_argument('--dynatemp-high', type=float, default=_d['dynatemp_high'], m
|
||||||
group.add_argument('--dynatemp-exponent', type=float, default=_d['dynatemp_exponent'], metavar='N', help='Dynamic temperature exponent')
|
group.add_argument('--dynatemp-exponent', type=float, default=_d['dynatemp_exponent'], metavar='N', help='Dynamic temperature exponent')
|
||||||
group.add_argument('--smoothing-factor', type=float, default=_d['smoothing_factor'], metavar='N', help='Smoothing factor')
|
group.add_argument('--smoothing-factor', type=float, default=_d['smoothing_factor'], metavar='N', help='Smoothing factor')
|
||||||
group.add_argument('--smoothing-curve', type=float, default=_d['smoothing_curve'], metavar='N', help='Smoothing curve')
|
group.add_argument('--smoothing-curve', type=float, default=_d['smoothing_curve'], metavar='N', help='Smoothing curve')
|
||||||
group.add_argument('--top-p', type=float, default=0.95, metavar='N', help='Top P')
|
|
||||||
group.add_argument('--top-k', type=int, default=_d['top_k'], metavar='N', help='Top K')
|
|
||||||
group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P')
|
group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P')
|
||||||
group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma')
|
group.add_argument('--top-p', type=float, default=_d['top_p'], metavar='N', help='Top P')
|
||||||
|
group.add_argument('--top-k', type=int, default=_d['top_k'], metavar='N', help='Top K')
|
||||||
group.add_argument('--typical-p', type=float, default=_d['typical_p'], metavar='N', help='Typical P')
|
group.add_argument('--typical-p', type=float, default=_d['typical_p'], metavar='N', help='Typical P')
|
||||||
group.add_argument('--xtc-threshold', type=float, default=_d['xtc_threshold'], metavar='N', help='XTC threshold')
|
group.add_argument('--xtc-threshold', type=float, default=_d['xtc_threshold'], metavar='N', help='XTC threshold')
|
||||||
group.add_argument('--xtc-probability', type=float, default=_d['xtc_probability'], metavar='N', help='XTC probability')
|
group.add_argument('--xtc-probability', type=float, default=_d['xtc_probability'], metavar='N', help='XTC probability')
|
||||||
|
|
@ -187,6 +191,7 @@ group.add_argument('--epsilon-cutoff', type=float, default=_d['epsilon_cutoff'],
|
||||||
group.add_argument('--eta-cutoff', type=float, default=_d['eta_cutoff'], metavar='N', help='Eta cutoff')
|
group.add_argument('--eta-cutoff', type=float, default=_d['eta_cutoff'], metavar='N', help='Eta cutoff')
|
||||||
group.add_argument('--tfs', type=float, default=_d['tfs'], metavar='N', help='TFS')
|
group.add_argument('--tfs', type=float, default=_d['tfs'], metavar='N', help='TFS')
|
||||||
group.add_argument('--top-a', type=float, default=_d['top_a'], metavar='N', help='Top A')
|
group.add_argument('--top-a', type=float, default=_d['top_a'], metavar='N', help='Top A')
|
||||||
|
group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma')
|
||||||
group.add_argument('--adaptive-target', type=float, default=_d['adaptive_target'], metavar='N', help='Adaptive target')
|
group.add_argument('--adaptive-target', type=float, default=_d['adaptive_target'], metavar='N', help='Adaptive target')
|
||||||
group.add_argument('--adaptive-decay', type=float, default=_d['adaptive_decay'], metavar='N', help='Adaptive decay')
|
group.add_argument('--adaptive-decay', type=float, default=_d['adaptive_decay'], metavar='N', help='Adaptive decay')
|
||||||
group.add_argument('--dry-multiplier', type=float, default=_d['dry_multiplier'], metavar='N', help='DRY multiplier')
|
group.add_argument('--dry-multiplier', type=float, default=_d['dry_multiplier'], metavar='N', help='DRY multiplier')
|
||||||
|
|
@ -258,10 +263,8 @@ settings = {
|
||||||
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>',
|
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>',
|
||||||
'enable_web_search': False,
|
'enable_web_search': False,
|
||||||
'web_search_pages': 3,
|
'web_search_pages': 3,
|
||||||
'selected_tools': [],
|
|
||||||
'mcp_servers': '',
|
|
||||||
'prompt-notebook': '',
|
'prompt-notebook': '',
|
||||||
'preset': 'Top-P' if (user_data_dir / 'presets/Top-P.yaml').exists() else None,
|
'preset': 'Qwen3 - Thinking' if (user_data_dir / 'presets/Qwen3 - Thinking.yaml').exists() else None,
|
||||||
'max_new_tokens': 512,
|
'max_new_tokens': 512,
|
||||||
'max_new_tokens_min': 1,
|
'max_new_tokens_min': 1,
|
||||||
'max_new_tokens_max': 4096,
|
'max_new_tokens_max': 4096,
|
||||||
|
|
@ -286,7 +289,7 @@ settings = {
|
||||||
'include_past_attachments': True,
|
'include_past_attachments': True,
|
||||||
|
|
||||||
# Generation parameters - Curve shape
|
# Generation parameters - Curve shape
|
||||||
'temperature': neutral_samplers['temperature'],
|
'temperature': 0.6,
|
||||||
'dynatemp_low': neutral_samplers['dynatemp_low'],
|
'dynatemp_low': neutral_samplers['dynatemp_low'],
|
||||||
'dynatemp_high': neutral_samplers['dynatemp_high'],
|
'dynatemp_high': neutral_samplers['dynatemp_high'],
|
||||||
'dynatemp_exponent': neutral_samplers['dynatemp_exponent'],
|
'dynatemp_exponent': neutral_samplers['dynatemp_exponent'],
|
||||||
|
|
@ -294,10 +297,9 @@ settings = {
|
||||||
'smoothing_curve': neutral_samplers['smoothing_curve'],
|
'smoothing_curve': neutral_samplers['smoothing_curve'],
|
||||||
|
|
||||||
# Generation parameters - Curve cutoff
|
# Generation parameters - Curve cutoff
|
||||||
'top_p': 0.95,
|
|
||||||
'top_k': neutral_samplers['top_k'],
|
|
||||||
'min_p': neutral_samplers['min_p'],
|
'min_p': neutral_samplers['min_p'],
|
||||||
'top_n_sigma': neutral_samplers['top_n_sigma'],
|
'top_p': 0.95,
|
||||||
|
'top_k': 20,
|
||||||
'typical_p': neutral_samplers['typical_p'],
|
'typical_p': neutral_samplers['typical_p'],
|
||||||
'xtc_threshold': neutral_samplers['xtc_threshold'],
|
'xtc_threshold': neutral_samplers['xtc_threshold'],
|
||||||
'xtc_probability': neutral_samplers['xtc_probability'],
|
'xtc_probability': neutral_samplers['xtc_probability'],
|
||||||
|
|
@ -305,6 +307,7 @@ settings = {
|
||||||
'eta_cutoff': neutral_samplers['eta_cutoff'],
|
'eta_cutoff': neutral_samplers['eta_cutoff'],
|
||||||
'tfs': neutral_samplers['tfs'],
|
'tfs': neutral_samplers['tfs'],
|
||||||
'top_a': neutral_samplers['top_a'],
|
'top_a': neutral_samplers['top_a'],
|
||||||
|
'top_n_sigma': neutral_samplers['top_n_sigma'],
|
||||||
'adaptive_target': neutral_samplers['adaptive_target'],
|
'adaptive_target': neutral_samplers['adaptive_target'],
|
||||||
'adaptive_decay': neutral_samplers['adaptive_decay'],
|
'adaptive_decay': neutral_samplers['adaptive_decay'],
|
||||||
|
|
||||||
|
|
@ -344,7 +347,7 @@ settings = {
|
||||||
'greeting': 'How can I help you today?',
|
'greeting': 'How can I help you today?',
|
||||||
'custom_system_message': '',
|
'custom_system_message': '',
|
||||||
'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}",
|
'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}",
|
||||||
'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {{- '[Tool result: ' + message['content'] + ']\\n' -}}\n {%- elif message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- elif message['tool_calls'] is defined and message['tool_calls'] -%}\n {%- for tc in message['tool_calls'] -%}\n {{- '[Calling: ' + tc['function']['name'] + '(' + tc['function']['arguments'] + ')]\\n' -}}\n {%- endfor -%}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}",
|
'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}",
|
||||||
|
|
||||||
# Extensions
|
# Extensions
|
||||||
'default_extensions': [],
|
'default_extensions': [],
|
||||||
|
|
@ -364,7 +367,7 @@ settings = {
|
||||||
'image_llm_variations_prompt': 'Write a variation of the image generation prompt above. Consider the intent of the user with that prompt and write something that will likely please them, with added details. Output only the new prompt. Do not add any explanations, prefixes, or additional text.',
|
'image_llm_variations_prompt': 'Write a variation of the image generation prompt above. Consider the intent of the user with that prompt and write something that will likely please them, with added details. Output only the new prompt. Do not add any explanations, prefixes, or additional text.',
|
||||||
'image_model_menu': 'None',
|
'image_model_menu': 'None',
|
||||||
'image_dtype': 'bfloat16',
|
'image_dtype': 'bfloat16',
|
||||||
'image_attn_backend': 'sdpa',
|
'image_attn_backend': 'flash_attention_2',
|
||||||
'image_cpu_offload': False,
|
'image_cpu_offload': False,
|
||||||
'image_compile': False,
|
'image_compile': False,
|
||||||
'image_quant': 'none',
|
'image_quant': 'none',
|
||||||
|
|
@ -392,16 +395,9 @@ def do_cmd_flags_warnings():
|
||||||
if args.share:
|
if args.share:
|
||||||
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
||||||
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
||||||
logger.warning("You are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
|
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
|
||||||
if args.multi_user:
|
if args.multi_user:
|
||||||
logger.warning(
|
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
|
||||||
'Multi-user mode is enabled. Known limitations:'
|
|
||||||
'\n- The Stop button stops generation for all users, not just you.'
|
|
||||||
'\n- Chat history is not saved and will be lost on page refresh.'
|
|
||||||
'\n- Only one user can generate at a time unless using a parallel-capable backend (e.g. llama.cpp with --parallel N for N > 1, or ExLlamaV3).'
|
|
||||||
'\n\nThis mode works best for small trusted teams.'
|
|
||||||
'\n\nDo not expose publicly. Grayed-out actions can easily be bypassed client-side.\n'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_image_model_cli_overrides():
|
def apply_image_model_cli_overrides():
|
||||||
|
|
@ -437,6 +433,16 @@ def fix_loader_name(name):
|
||||||
return 'TensorRT-LLM'
|
return 'TensorRT-LLM'
|
||||||
|
|
||||||
|
|
||||||
|
def add_extension(name, last=False):
|
||||||
|
if args.extensions is None:
|
||||||
|
args.extensions = [name]
|
||||||
|
elif last:
|
||||||
|
args.extensions = [x for x in args.extensions if x != name]
|
||||||
|
args.extensions.append(name)
|
||||||
|
elif name not in args.extensions:
|
||||||
|
args.extensions.append(name)
|
||||||
|
|
||||||
|
|
||||||
def is_chat():
|
def is_chat():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
@ -445,18 +451,36 @@ def load_user_config():
|
||||||
'''
|
'''
|
||||||
Loads custom model-specific settings
|
Loads custom model-specific settings
|
||||||
'''
|
'''
|
||||||
user_config = {}
|
|
||||||
if Path(f'{args.model_dir}/config-user.yaml').exists():
|
if Path(f'{args.model_dir}/config-user.yaml').exists():
|
||||||
file_content = open(f'{args.model_dir}/config-user.yaml', 'r').read().strip()
|
file_content = open(f'{args.model_dir}/config-user.yaml', 'r').read().strip()
|
||||||
|
|
||||||
if file_content:
|
if file_content:
|
||||||
user_config = yaml.safe_load(file_content)
|
user_config = yaml.safe_load(file_content)
|
||||||
|
else:
|
||||||
|
user_config = {}
|
||||||
|
else:
|
||||||
|
user_config = {}
|
||||||
|
|
||||||
return user_config
|
return user_config
|
||||||
|
|
||||||
|
|
||||||
args.loader = fix_loader_name(args.loader)
|
args.loader = fix_loader_name(args.loader)
|
||||||
|
|
||||||
|
# Activate the API extension
|
||||||
|
if args.api or args.public_api:
|
||||||
|
add_extension('openai', last=True)
|
||||||
|
|
||||||
|
# Load model-specific settings
|
||||||
|
p = Path(f'{args.model_dir}/config.yaml')
|
||||||
|
if p.exists():
|
||||||
|
model_config = yaml.safe_load(open(p, 'r').read())
|
||||||
|
else:
|
||||||
|
model_config = {}
|
||||||
|
del p
|
||||||
|
|
||||||
|
|
||||||
# Load custom model-specific settings
|
# Load custom model-specific settings
|
||||||
user_config = load_user_config()
|
user_config = load_user_config()
|
||||||
|
|
||||||
|
model_config = OrderedDict(model_config)
|
||||||
user_config = OrderedDict(user_config)
|
user_config = OrderedDict(user_config)
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import html
|
||||||
import pprint
|
import pprint
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -17,7 +18,9 @@ from modules.utils import check_model_loaded
|
||||||
|
|
||||||
|
|
||||||
def generate_reply(*args, **kwargs):
|
def generate_reply(*args, **kwargs):
|
||||||
models.load_model_if_idle_unloaded()
|
if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
|
||||||
|
from modules.models import load_model
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
state = args[1] if len(args) > 1 else kwargs.get('state', {})
|
state = args[1] if len(args) > 1 else kwargs.get('state', {})
|
||||||
use_parallel = (
|
use_parallel = (
|
||||||
|
|
@ -29,16 +32,10 @@ def generate_reply(*args, **kwargs):
|
||||||
if not use_parallel:
|
if not use_parallel:
|
||||||
shared.generation_lock.acquire()
|
shared.generation_lock.acquire()
|
||||||
|
|
||||||
with models._generation_count_lock:
|
|
||||||
models.active_generation_count += 1
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for result in _generate_reply(*args, **kwargs):
|
for result in _generate_reply(*args, **kwargs):
|
||||||
yield result
|
yield result
|
||||||
finally:
|
finally:
|
||||||
with models._generation_count_lock:
|
|
||||||
models.active_generation_count -= 1
|
|
||||||
|
|
||||||
models.last_generation_time = time.time()
|
models.last_generation_time = time.time()
|
||||||
if not use_parallel:
|
if not use_parallel:
|
||||||
shared.generation_lock.release()
|
shared.generation_lock.release()
|
||||||
|
|
@ -81,13 +78,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||||
reply = ''
|
reply = ''
|
||||||
is_stream = state['stream']
|
is_stream = state['stream']
|
||||||
if len(all_stop_strings) > 0 and not state['stream']:
|
if len(all_stop_strings) > 0 and not state['stream']:
|
||||||
original_logits_processor = state.get('logits_processor')
|
|
||||||
stop_event_ref = state.pop('stop_event', None)
|
stop_event_ref = state.pop('stop_event', None)
|
||||||
state = copy.deepcopy(state)
|
state = copy.deepcopy(state)
|
||||||
if stop_event_ref is not None:
|
if stop_event_ref is not None:
|
||||||
state['stop_event'] = stop_event_ref
|
state['stop_event'] = stop_event_ref
|
||||||
if original_logits_processor is not None:
|
|
||||||
state['logits_processor'] = original_logits_processor
|
|
||||||
state['stream'] = True
|
state['stream'] = True
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
|
|
@ -129,8 +123,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||||
|
|
||||||
|
|
||||||
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||||
if shared.tokenizer is None:
|
|
||||||
models.load_model_if_idle_unloaded()
|
|
||||||
if shared.tokenizer is None:
|
if shared.tokenizer is None:
|
||||||
raise ValueError('No tokenizer is loaded')
|
raise ValueError('No tokenizer is loaded')
|
||||||
|
|
||||||
|
|
@ -181,8 +173,6 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||||
|
|
||||||
|
|
||||||
def decode(output_ids, skip_special_tokens=True):
|
def decode(output_ids, skip_special_tokens=True):
|
||||||
if shared.tokenizer is None:
|
|
||||||
models.load_model_if_idle_unloaded()
|
|
||||||
if shared.tokenizer is None:
|
if shared.tokenizer is None:
|
||||||
raise ValueError('No tokenizer is loaded')
|
raise ValueError('No tokenizer is loaded')
|
||||||
|
|
||||||
|
|
@ -385,7 +375,7 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
|
||||||
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
|
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
|
||||||
|
|
||||||
if state['custom_token_bans']:
|
if state['custom_token_bans']:
|
||||||
to_ban = [int(x.strip()) for x in state['custom_token_bans'].split(',') if x.strip()]
|
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
|
||||||
if len(to_ban) > 0:
|
if len(to_ban) > 0:
|
||||||
if generate_params.get('suppress_tokens', None):
|
if generate_params.get('suppress_tokens', None):
|
||||||
generate_params['suppress_tokens'] += to_ban
|
generate_params['suppress_tokens'] += to_ban
|
||||||
|
|
@ -484,7 +474,7 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
|
||||||
yield cumulative_reply
|
yield cumulative_reply
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to generate reply (HF)")
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
original_tokens = len(original_input_ids[0])
|
original_tokens = len(original_input_ids[0])
|
||||||
|
|
@ -517,7 +507,7 @@ def generate_reply_custom(question, original_question, state, stopping_strings=N
|
||||||
yield reply
|
yield reply
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to generate reply (custom)")
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,724 +0,0 @@
|
||||||
import json
|
|
||||||
import random
|
|
||||||
import re
|
|
||||||
|
|
||||||
from modules.reasoning import extract_reasoning
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tool_call(name, arguments):
|
|
||||||
return {"type": "function", "function": {"name": name, "arguments": arguments}}
|
|
||||||
|
|
||||||
|
|
||||||
def get_tool_call_id() -> str:
|
|
||||||
letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
||||||
b = [random.choice(letter_bytes) for _ in range(8)]
|
|
||||||
return "call_" + "".join(b).lower()
|
|
||||||
|
|
||||||
|
|
||||||
# All known opening markers for tool calls across model formats.
|
|
||||||
TOOL_CALL_OPENING_MARKERS = [
|
|
||||||
'<tool_call>',
|
|
||||||
'<function_call>',
|
|
||||||
'<minimax:tool_call>',
|
|
||||||
'<|tool_call_begin|>',
|
|
||||||
'<|tool_calls_section_begin|>',
|
|
||||||
'<|tool▁call▁begin|>',
|
|
||||||
'<|tool▁calls▁begin|>',
|
|
||||||
'[TOOL_CALLS]',
|
|
||||||
'to=functions.',
|
|
||||||
'<|channel|>commentary',
|
|
||||||
'<|tool_call>call:',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False, partial_match=True):
|
|
||||||
'''
|
|
||||||
Check whether streaming output should be withheld because it may
|
|
||||||
contain tool-call markup.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Full accumulated internal text.
|
|
||||||
markers: Template-specific markers for partial-prefix matching.
|
|
||||||
If None, falls back to TOOL_CALL_OPENING_MARKERS.
|
|
||||||
tool_names: List of tool function names.
|
|
||||||
check_bare_names: Whether to do partial-prefix matching on tool
|
|
||||||
names (for models with unknown template format).
|
|
||||||
partial_match: Whether to check partial prefixes of markers/names.
|
|
||||||
Set to False for end-of-generation checks where a
|
|
||||||
partial prefix is just normal text, not an incomplete
|
|
||||||
tool call.
|
|
||||||
'''
|
|
||||||
# Strip thinking blocks so tool-call syntax inside <think> doesn't
|
|
||||||
# trigger false positives.
|
|
||||||
_, text = extract_reasoning(text)
|
|
||||||
|
|
||||||
# Full marker found in text → buffer permanently.
|
|
||||||
# Always checks ALL known markers regardless of template (cheap safety net).
|
|
||||||
for marker in TOOL_CALL_OPENING_MARKERS:
|
|
||||||
if marker in text:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Bare function-name full match: "get_weather{...}" or "get_weather {...}"
|
|
||||||
if tool_names:
|
|
||||||
for name in tool_names:
|
|
||||||
if name + '{' in text or name + ' {' in text:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not partial_match:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Partial-prefix matching: only for template-specific markers.
|
|
||||||
for marker in (markers if markers is not None else TOOL_CALL_OPENING_MARKERS):
|
|
||||||
for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1):
|
|
||||||
if text.endswith(marker[:prefix_len]):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Bare-name partial matching: only when template format is unknown.
|
|
||||||
if check_bare_names and tool_names:
|
|
||||||
for name in tool_names:
|
|
||||||
if text.endswith(name):
|
|
||||||
return True
|
|
||||||
for prefix_len in range(min(len(name) - 1, len(text)), 0, -1):
|
|
||||||
if text.endswith(name[:prefix_len]):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]):
|
|
||||||
# check if property 'function' exists and is a dictionary, otherwise adapt dict
|
|
||||||
if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
|
|
||||||
candidate_dict = {"type": "function", "function": candidate_dict}
|
|
||||||
if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
|
|
||||||
candidate_dict['name'] = candidate_dict['function']
|
|
||||||
del candidate_dict['function']
|
|
||||||
candidate_dict = {"type": "function", "function": candidate_dict}
|
|
||||||
if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
|
|
||||||
# check if 'name' exists within 'function' and is part of known tools
|
|
||||||
if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
|
|
||||||
candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value
|
|
||||||
# map property 'parameters' used by some older models to 'arguments'
|
|
||||||
if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
|
|
||||||
candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
|
|
||||||
del candidate_dict["function"]["parameters"]
|
|
||||||
return candidate_dict
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_balanced_json(text: str, start: int) -> str | None:
|
|
||||||
"""Extract a balanced JSON object from text starting at the given position.
|
|
||||||
|
|
||||||
Walks through the string tracking brace depth and string boundaries
|
|
||||||
to correctly handle arbitrary nesting levels.
|
|
||||||
"""
|
|
||||||
if start >= len(text) or text[start] != '{':
|
|
||||||
return None
|
|
||||||
depth = 0
|
|
||||||
in_string = False
|
|
||||||
escape_next = False
|
|
||||||
for i in range(start, len(text)):
|
|
||||||
c = text[i]
|
|
||||||
if escape_next:
|
|
||||||
escape_next = False
|
|
||||||
continue
|
|
||||||
if c == '\\' and in_string:
|
|
||||||
escape_next = True
|
|
||||||
continue
|
|
||||||
if c == '"':
|
|
||||||
in_string = not in_string
|
|
||||||
continue
|
|
||||||
if in_string:
|
|
||||||
continue
|
|
||||||
if c == '{':
|
|
||||||
depth += 1
|
|
||||||
elif c == '}':
|
|
||||||
depth -= 1
|
|
||||||
if depth == 0:
|
|
||||||
return text[start:i + 1]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_channel_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse channel-based tool calls used by GPT-OSS and similar models.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
<|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"}
|
|
||||||
or:
|
|
||||||
<|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
# Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
|
|
||||||
# Pattern 2: to=functions.NAME after <|channel|> (alternative format)
|
|
||||||
patterns = [
|
|
||||||
r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
|
|
||||||
r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
|
|
||||||
]
|
|
||||||
for pattern in patterns:
|
|
||||||
for m in re.finditer(pattern, answer):
|
|
||||||
func_name = m.group(1).strip()
|
|
||||||
if func_name not in tool_names:
|
|
||||||
continue
|
|
||||||
json_str = _extract_balanced_json(answer, m.end())
|
|
||||||
if json_str is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
arguments = json.loads(json_str)
|
|
||||||
if start_pos is None:
|
|
||||||
prefix = answer.rfind('<|start|>assistant', 0, m.start())
|
|
||||||
start_pos = prefix if prefix != -1 else m.start()
|
|
||||||
matches.append(_make_tool_call(func_name, arguments))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
if matches:
|
|
||||||
break
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
[TOOL_CALLS]func_name[ARGS]{"arg": "value"}
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
for m in re.finditer(
|
|
||||||
r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
|
|
||||||
answer
|
|
||||||
):
|
|
||||||
func_name = m.group(1).strip()
|
|
||||||
if func_name not in tool_names:
|
|
||||||
continue
|
|
||||||
json_str = _extract_balanced_json(answer, m.end())
|
|
||||||
if json_str is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
arguments = json.loads(json_str)
|
|
||||||
if start_pos is None:
|
|
||||||
start_pos = m.start()
|
|
||||||
matches.append(_make_tool_call(func_name, arguments))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse bare function-name style tool calls used by Mistral and similar models.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
functionName{"arg": "value"}
|
|
||||||
Multiple calls are concatenated directly or separated by whitespace.
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
# Match tool name followed by opening brace, then extract balanced JSON
|
|
||||||
escaped_names = [re.escape(name) for name in tool_names]
|
|
||||||
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
|
|
||||||
for match in re.finditer(pattern, answer):
|
|
||||||
text = match.group(0)
|
|
||||||
name = None
|
|
||||||
for n in tool_names:
|
|
||||||
if text.startswith(n):
|
|
||||||
name = n
|
|
||||||
break
|
|
||||||
if not name:
|
|
||||||
continue
|
|
||||||
brace_start = match.end() - 1
|
|
||||||
json_str = _extract_balanced_json(answer, brace_start)
|
|
||||||
if json_str is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
arguments = json.loads(json_str)
|
|
||||||
if start_pos is None:
|
|
||||||
start_pos = match.start()
|
|
||||||
matches.append(_make_tool_call(name, arguments))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse XML-parameter style tool calls used by Qwen3.5 and similar models.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
<tool_call>
|
|
||||||
<function=function_name>
|
|
||||||
<parameter=param_name>value</parameter>
|
|
||||||
</function>
|
|
||||||
</tool_call>
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
|
|
||||||
tc_content = tc_match.group(1)
|
|
||||||
func_match = re.search(r'<function=([^>]+)>', tc_content)
|
|
||||||
if not func_match:
|
|
||||||
continue
|
|
||||||
func_name = func_match.group(1).strip()
|
|
||||||
if func_name not in tool_names:
|
|
||||||
continue
|
|
||||||
arguments = {}
|
|
||||||
for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
|
|
||||||
param_name = param_match.group(1).strip()
|
|
||||||
param_value = param_match.group(2).strip()
|
|
||||||
try:
|
|
||||||
param_value = json.loads(param_value)
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
pass # keep as string
|
|
||||||
arguments[param_name] = param_value
|
|
||||||
if start_pos is None:
|
|
||||||
start_pos = tc_match.start()
|
|
||||||
matches.append(_make_tool_call(func_name, arguments))
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_kimi_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse Kimi-K2-style tool calls using pipe-delimited tokens.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
<|tool_calls_section_begin|>
|
|
||||||
<|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
|
|
||||||
<|tool_calls_section_end|>
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
for m in re.finditer(
|
|
||||||
r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
|
|
||||||
answer
|
|
||||||
):
|
|
||||||
func_name = m.group(1).strip()
|
|
||||||
if func_name not in tool_names:
|
|
||||||
continue
|
|
||||||
json_str = _extract_balanced_json(answer, m.end())
|
|
||||||
if json_str is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
arguments = json.loads(json_str)
|
|
||||||
if start_pos is None:
|
|
||||||
# Check for section begin marker before the call marker
|
|
||||||
section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start())
|
|
||||||
start_pos = section if section != -1 else m.start()
|
|
||||||
matches.append(_make_tool_call(func_name, arguments))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_minimax_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse MiniMax-style tool calls using invoke/parameter XML tags.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
<minimax:tool_call>
|
|
||||||
<invoke name="function_name">
|
|
||||||
<parameter name="param_name">value</parameter>
|
|
||||||
</invoke>
|
|
||||||
</minimax:tool_call>
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
|
|
||||||
tc_content = tc_match.group(1)
|
|
||||||
# Split on <invoke> to handle multiple parallel calls in one block
|
|
||||||
for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
|
|
||||||
func_name = invoke_match.group(1).strip()
|
|
||||||
if func_name not in tool_names:
|
|
||||||
continue
|
|
||||||
invoke_body = invoke_match.group(2)
|
|
||||||
arguments = {}
|
|
||||||
for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
|
|
||||||
param_name = param_match.group(1).strip()
|
|
||||||
param_value = param_match.group(2).strip()
|
|
||||||
try:
|
|
||||||
param_value = json.loads(param_value)
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
pass # keep as string
|
|
||||||
arguments[param_name] = param_value
|
|
||||||
if start_pos is None:
|
|
||||||
start_pos = tc_match.start()
|
|
||||||
matches.append(_make_tool_call(func_name, arguments))
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
<|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
for m in re.finditer(
|
|
||||||
r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*',
|
|
||||||
answer
|
|
||||||
):
|
|
||||||
func_name = m.group(1).strip()
|
|
||||||
if func_name not in tool_names:
|
|
||||||
continue
|
|
||||||
json_str = _extract_balanced_json(answer, m.end())
|
|
||||||
if json_str is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
arguments = json.loads(json_str)
|
|
||||||
if start_pos is None:
|
|
||||||
# Check for section begin marker before the call marker
|
|
||||||
section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start())
|
|
||||||
start_pos = section if section != -1 else m.start()
|
|
||||||
matches.append(_make_tool_call(func_name, arguments))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse GLM-style tool calls using arg_key/arg_value XML pairs.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
<tool_call>function_name
|
|
||||||
<arg_key>key1</arg_key>
|
|
||||||
<arg_value>value1</arg_value>
|
|
||||||
</tool_call>
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
|
|
||||||
tc_content = tc_match.group(1)
|
|
||||||
# First non-tag text is the function name
|
|
||||||
name_match = re.match(r'([^<\s]+)', tc_content.strip())
|
|
||||||
if not name_match:
|
|
||||||
continue
|
|
||||||
func_name = name_match.group(1).strip()
|
|
||||||
if func_name not in tool_names:
|
|
||||||
continue
|
|
||||||
# Extract arg_key/arg_value pairs
|
|
||||||
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
|
|
||||||
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
|
|
||||||
if len(keys) != len(vals):
|
|
||||||
continue
|
|
||||||
arguments = {}
|
|
||||||
for k, v in zip(keys, vals):
|
|
||||||
try:
|
|
||||||
v = json.loads(v)
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
pass # keep as string
|
|
||||||
arguments[k] = v
|
|
||||||
if start_pos is None:
|
|
||||||
start_pos = tc_match.start()
|
|
||||||
matches.append(_make_tool_call(func_name, arguments))
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_gemma4_balanced(text, start):
|
|
||||||
"""Extract balanced braces from Gemma 4 format, using <|"|> as string delimiters."""
|
|
||||||
if start >= len(text) or text[start] != '{':
|
|
||||||
return None
|
|
||||||
depth = 0
|
|
||||||
in_string = False
|
|
||||||
quote_token = '<|"|>'
|
|
||||||
quote_len = len(quote_token)
|
|
||||||
i = start
|
|
||||||
while i < len(text):
|
|
||||||
if text[i:i + quote_len] == quote_token:
|
|
||||||
in_string = not in_string
|
|
||||||
i += quote_len
|
|
||||||
continue
|
|
||||||
if in_string:
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
c = text[i]
|
|
||||||
if c == '{':
|
|
||||||
depth += 1
|
|
||||||
elif c == '}':
|
|
||||||
depth -= 1
|
|
||||||
if depth == 0:
|
|
||||||
return text[start:i + 1]
|
|
||||||
i += 1
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_gemma4_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse Gemma 4-style tool calls.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
<|tool_call>call:func_name{key:<|"|>value<|"|>,...}<tool_call|>
|
|
||||||
|
|
||||||
Values use <|"|> tokens instead of standard JSON quotes, and keys are
|
|
||||||
bare identifiers.
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
|
|
||||||
for m in re.finditer(r'<\|tool_call>call:([^\s{]+)\s*', answer):
|
|
||||||
func_name = m.group(1).strip()
|
|
||||||
if func_name not in tool_names:
|
|
||||||
continue
|
|
||||||
|
|
||||||
brace_start = m.end()
|
|
||||||
if brace_start >= len(answer) or answer[brace_start] != '{':
|
|
||||||
continue
|
|
||||||
|
|
||||||
content = _extract_gemma4_balanced(answer, brace_start)
|
|
||||||
if content is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Convert to JSON: split on <|"|> tokens so that key quoting
|
|
||||||
# only applies outside string values (even-indexed parts),
|
|
||||||
# then rejoin with real quotes.
|
|
||||||
parts = content.split('<|"|>')
|
|
||||||
for idx in range(0, len(parts), 2):
|
|
||||||
parts[idx] = re.sub(r'(^|[{,\[])\s*(\w+)\s*:', r'\1"\2":', parts[idx])
|
|
||||||
json_str = '"'.join(parts)
|
|
||||||
|
|
||||||
try:
|
|
||||||
arguments = json.loads(json_str)
|
|
||||||
if start_pos is None:
|
|
||||||
start_pos = m.start()
|
|
||||||
matches.append(_make_tool_call(func_name, arguments))
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
|
|
||||||
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
[func_name(param1="value1", param2="value2"), func_name2(...)]
|
|
||||||
"""
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
# Match a bracketed list of function calls
|
|
||||||
bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
|
|
||||||
if not bracket_match:
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
inner = bracket_match.group(1)
|
|
||||||
|
|
||||||
# Build pattern for known tool names
|
|
||||||
escaped_names = [re.escape(name) for name in tool_names]
|
|
||||||
name_pattern = '|'.join(escaped_names)
|
|
||||||
|
|
||||||
for call_match in re.finditer(
|
|
||||||
r'(' + name_pattern + r')\(([^)]*)\)',
|
|
||||||
inner
|
|
||||||
):
|
|
||||||
func_name = call_match.group(1)
|
|
||||||
params_str = call_match.group(2).strip()
|
|
||||||
arguments = {}
|
|
||||||
|
|
||||||
if params_str:
|
|
||||||
# Parse key="value" pairs, handling commas inside quoted values
|
|
||||||
for param_match in re.finditer(
|
|
||||||
r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
|
|
||||||
params_str
|
|
||||||
):
|
|
||||||
param_name = param_match.group(1)
|
|
||||||
param_value = param_match.group(2).strip()
|
|
||||||
# Strip surrounding quotes
|
|
||||||
if (param_value.startswith('"') and param_value.endswith('"')) or \
|
|
||||||
(param_value.startswith("'") and param_value.endswith("'")):
|
|
||||||
param_value = param_value[1:-1]
|
|
||||||
# Try to parse as JSON for numeric/bool/null values
|
|
||||||
try:
|
|
||||||
param_value = json.loads(param_value)
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
pass
|
|
||||||
arguments[param_name] = param_value
|
|
||||||
|
|
||||||
if start_pos is None:
|
|
||||||
start_pos = bracket_match.start()
|
|
||||||
matches.append(_make_tool_call(func_name, arguments))
|
|
||||||
|
|
||||||
return matches, start_pos
|
|
||||||
|
|
||||||
|
|
||||||
# Format registry: maps template substrings to the parser and streaming
|
|
||||||
# markers for that format. When a format's hints are NOT found in the
|
|
||||||
# template, its parser and markers are excluded.
|
|
||||||
TOOL_CALL_FORMATS = [
|
|
||||||
{
|
|
||||||
'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'],
|
|
||||||
'parser': _parse_deep_seek_tool_calls,
|
|
||||||
'markers': ['<|tool▁call▁begin|>', '<|tool▁calls▁begin|>'],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'],
|
|
||||||
'parser': _parse_kimi_tool_calls,
|
|
||||||
'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'template_hints': ['to=functions.', '<|channel|>'],
|
|
||||||
'parser': _parse_channel_tool_calls,
|
|
||||||
'markers': ['to=functions.', '<|channel|>commentary'],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'template_hints': ['<|tool_call>call:'],
|
|
||||||
'parser': _parse_gemma4_tool_calls,
|
|
||||||
'markers': ['<|tool_call>call:'],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'template_hints': ['minimax:tool_call'],
|
|
||||||
'parser': _parse_minimax_tool_calls,
|
|
||||||
'markers': ['<minimax:tool_call>'],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'template_hints': ['<arg_key>'],
|
|
||||||
'parser': _parse_glm_tool_calls,
|
|
||||||
'markers': ['<tool_call>'],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'template_hints': ['<tool_call>'],
|
|
||||||
'parser': _parse_xml_param_tool_calls,
|
|
||||||
'markers': ['<tool_call>'],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'template_hints': ['[TOOL_CALLS]'],
|
|
||||||
'parser': _parse_mistral_token_tool_calls,
|
|
||||||
'markers': ['[TOOL_CALLS]'],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'template_hints': ['<function_call>'],
|
|
||||||
'parser': None,
|
|
||||||
'markers': ['<function_call>'],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Default ordered list of all specialized parsers.
|
|
||||||
ALL_PARSERS = [
|
|
||||||
_parse_deep_seek_tool_calls,
|
|
||||||
_parse_kimi_tool_calls,
|
|
||||||
_parse_channel_tool_calls,
|
|
||||||
_parse_gemma4_tool_calls,
|
|
||||||
_parse_minimax_tool_calls,
|
|
||||||
_parse_glm_tool_calls,
|
|
||||||
_parse_xml_param_tool_calls,
|
|
||||||
_parse_mistral_token_tool_calls,
|
|
||||||
_parse_bare_name_tool_calls,
|
|
||||||
_parse_pythonic_tool_calls,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def detect_tool_call_format(template_str):
|
|
||||||
"""Inspect a chat/instruction template to determine which tool call
|
|
||||||
formats are relevant.
|
|
||||||
|
|
||||||
Uses an exclude-based approach: starts with all parsers/markers,
|
|
||||||
then removes the ones whose hints are not found in the template.
|
|
||||||
|
|
||||||
Returns (parsers, streaming_markers, check_bare_names).
|
|
||||||
"""
|
|
||||||
if not template_str:
|
|
||||||
return None, TOOL_CALL_OPENING_MARKERS, True
|
|
||||||
|
|
||||||
matched_any = False
|
|
||||||
exclude_parsers = []
|
|
||||||
exclude_markers = []
|
|
||||||
matched_markers = []
|
|
||||||
|
|
||||||
for fmt in TOOL_CALL_FORMATS:
|
|
||||||
if any(hint in template_str for hint in fmt['template_hints']):
|
|
||||||
matched_any = True
|
|
||||||
matched_markers.extend(fmt['markers'])
|
|
||||||
else:
|
|
||||||
if fmt['parser'] is not None:
|
|
||||||
exclude_parsers.append(fmt['parser'])
|
|
||||||
exclude_markers.extend(fmt['markers'])
|
|
||||||
|
|
||||||
if not matched_any:
|
|
||||||
return None, TOOL_CALL_OPENING_MARKERS, True
|
|
||||||
|
|
||||||
parsers = [p for p in ALL_PARSERS if p not in exclude_parsers]
|
|
||||||
markers = [m for m in TOOL_CALL_OPENING_MARKERS if m not in exclude_markers or m in matched_markers]
|
|
||||||
|
|
||||||
return parsers, markers, False
|
|
||||||
|
|
||||||
|
|
||||||
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None):
|
|
||||||
# Strip thinking blocks so tool-call syntax inside <think> is ignored.
|
|
||||||
original_answer = answer
|
|
||||||
_, answer = extract_reasoning(answer)
|
|
||||||
# Reasoning extraction returns empty content when GPT-OSS internal
|
|
||||||
# markup (<|start|>assistant…) follows the thinking block without a
|
|
||||||
# content tag. Fall back to the full text so tool-call markers can
|
|
||||||
# be found.
|
|
||||||
if not answer.strip():
|
|
||||||
answer = original_answer
|
|
||||||
reasoning_offset = 0
|
|
||||||
else:
|
|
||||||
reasoning_offset = len(original_answer) - len(answer)
|
|
||||||
|
|
||||||
matches = []
|
|
||||||
start_pos = None
|
|
||||||
|
|
||||||
def _return(matches, start_pos):
|
|
||||||
if return_prefix:
|
|
||||||
prefix = original_answer[:start_pos + reasoning_offset] if matches and start_pos is not None else ''
|
|
||||||
return matches, prefix
|
|
||||||
return matches
|
|
||||||
|
|
||||||
# Try specialized parsers.
|
|
||||||
for parser in (parsers if parsers is not None else ALL_PARSERS):
|
|
||||||
matches, start_pos = parser(answer, tool_names)
|
|
||||||
if matches:
|
|
||||||
return _return(matches, start_pos)
|
|
||||||
|
|
||||||
# Generic fallback: regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
|
|
||||||
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
|
|
||||||
|
|
||||||
for pattern in patterns:
|
|
||||||
for match in re.finditer(pattern, answer, re.DOTALL):
|
|
||||||
if match.group(2) is None:
|
|
||||||
continue
|
|
||||||
# remove backtick wraps if present
|
|
||||||
candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
|
|
||||||
candidate = re.sub(r"```$", "", candidate.strip())
|
|
||||||
# unwrap inner tags
|
|
||||||
candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
|
|
||||||
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
|
|
||||||
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
|
|
||||||
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
|
|
||||||
if not candidate.strip().startswith("["):
|
|
||||||
candidate = "[" + candidate + "]"
|
|
||||||
|
|
||||||
candidates = []
|
|
||||||
try:
|
|
||||||
# parse the candidate JSON into a dictionary
|
|
||||||
candidates = json.loads(candidate)
|
|
||||||
if not isinstance(candidates, list):
|
|
||||||
candidates = [candidates]
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# Ignore invalid JSON silently
|
|
||||||
continue
|
|
||||||
|
|
||||||
for candidate_dict in candidates:
|
|
||||||
checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
|
|
||||||
if checked_candidate is not None:
|
|
||||||
if start_pos is None:
|
|
||||||
start_pos = match.start()
|
|
||||||
matches.append(checked_candidate)
|
|
||||||
|
|
||||||
# last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
|
|
||||||
if len(matches) == 0:
|
|
||||||
try:
|
|
||||||
candidate = answer
|
|
||||||
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
|
|
||||||
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
|
|
||||||
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
|
|
||||||
if not candidate.strip().startswith("["):
|
|
||||||
candidate = "[" + candidate + "]"
|
|
||||||
# parse the candidate JSON into a dictionary
|
|
||||||
candidates = json.loads(candidate)
|
|
||||||
if not isinstance(candidates, list):
|
|
||||||
candidates = [candidates]
|
|
||||||
for candidate_dict in candidates:
|
|
||||||
if not isinstance(candidate_dict, dict):
|
|
||||||
continue
|
|
||||||
checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
|
|
||||||
if checked_candidate is not None:
|
|
||||||
matches.append(checked_candidate)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# Ignore invalid JSON silently
|
|
||||||
pass
|
|
||||||
|
|
||||||
return _return(matches, start_pos)
|
|
||||||
|
|
@ -1,185 +0,0 @@
|
||||||
import asyncio
|
|
||||||
import importlib.util
|
|
||||||
import json
|
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
from modules.utils import natural_keys, sanitize_filename
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_tools():
|
|
||||||
"""Return sorted list of tool script names from user_data/tools/*.py."""
|
|
||||||
tools_dir = shared.user_data_dir / 'tools'
|
|
||||||
tools_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
return sorted((p.stem for p in tools_dir.glob('*.py')), key=natural_keys)
|
|
||||||
|
|
||||||
|
|
||||||
def load_tools(selected_names):
|
|
||||||
"""
|
|
||||||
Import selected tool scripts and return their definitions and executors.
|
|
||||||
Returns (tool_defs, executors) where:
|
|
||||||
- tool_defs: list of OpenAI-format tool dicts
|
|
||||||
- executors: dict mapping function_name -> execute callable
|
|
||||||
"""
|
|
||||||
tool_defs = []
|
|
||||||
executors = {}
|
|
||||||
for name in selected_names:
|
|
||||||
name = sanitize_filename(name)
|
|
||||||
if not name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
path = shared.user_data_dir / 'tools' / f'{name}.py'
|
|
||||||
if not path.exists():
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
spec = importlib.util.spec_from_file_location(f"tool_{name}", str(path))
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
except Exception:
|
|
||||||
logger.exception(f'Failed to load tool script "{name}"')
|
|
||||||
continue
|
|
||||||
|
|
||||||
tool_def = getattr(module, 'tool', None)
|
|
||||||
execute_fn = getattr(module, 'execute', None)
|
|
||||||
if tool_def is None or execute_fn is None:
|
|
||||||
logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.')
|
|
||||||
continue
|
|
||||||
|
|
||||||
func_name = tool_def.get('function', {}).get('name', name)
|
|
||||||
if func_name in executors:
|
|
||||||
logger.warning(f'Tool "{name}" declares function name "{func_name}" which conflicts with an already loaded tool. Skipping.')
|
|
||||||
continue
|
|
||||||
tool_defs.append(tool_def)
|
|
||||||
executors[func_name] = execute_fn
|
|
||||||
|
|
||||||
return tool_defs, executors
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_mcp_servers(servers_str):
|
|
||||||
"""Parse MCP servers textbox: one server per line, format 'url' or 'url,Header: value,Header2: value2'."""
|
|
||||||
servers = []
|
|
||||||
for line in servers_str.strip().splitlines():
|
|
||||||
line = line.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
parts = line.split(',')
|
|
||||||
url = parts[0].strip()
|
|
||||||
headers = {}
|
|
||||||
for part in parts[1:]:
|
|
||||||
part = part.strip()
|
|
||||||
if ':' in part:
|
|
||||||
key, val = part.split(':', 1)
|
|
||||||
headers[key.strip()] = val.strip()
|
|
||||||
servers.append((url, headers))
|
|
||||||
return servers
|
|
||||||
|
|
||||||
|
|
||||||
def _mcp_tool_to_openai(tool):
|
|
||||||
"""Convert an MCP Tool object to OpenAI-format tool dict."""
|
|
||||||
return {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": tool.name,
|
|
||||||
"description": tool.description or "",
|
|
||||||
"parameters": tool.inputSchema or {"type": "object", "properties": {}}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def _mcp_session(url, headers, callback):
|
|
||||||
"""Open an MCP session and pass it to the callback."""
|
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
|
||||||
from mcp import ClientSession
|
|
||||||
|
|
||||||
async with streamablehttp_client(url, headers=headers or None) as (read_stream, write_stream, _):
|
|
||||||
async with ClientSession(read_stream, write_stream) as session:
|
|
||||||
await session.initialize()
|
|
||||||
return await callback(session)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_mcp_executor(name, url, headers):
|
|
||||||
def executor(arguments):
|
|
||||||
return asyncio.run(_call_mcp_tool(name, arguments, url, headers))
|
|
||||||
return executor
|
|
||||||
|
|
||||||
|
|
||||||
async def _connect_mcp_server(url, headers):
|
|
||||||
"""Connect to one MCP server and return (tool_defs, executors)."""
|
|
||||||
|
|
||||||
async def _discover(session):
|
|
||||||
result = await session.list_tools()
|
|
||||||
tool_defs = []
|
|
||||||
executors = {}
|
|
||||||
for tool in result.tools:
|
|
||||||
tool_defs.append(_mcp_tool_to_openai(tool))
|
|
||||||
executors[tool.name] = _make_mcp_executor(tool.name, url, headers)
|
|
||||||
return tool_defs, executors
|
|
||||||
|
|
||||||
return await _mcp_session(url, headers, _discover)
|
|
||||||
|
|
||||||
|
|
||||||
async def _call_mcp_tool(name, arguments, url, headers):
|
|
||||||
"""Connect to an MCP server and call a single tool."""
|
|
||||||
|
|
||||||
async def _invoke(session):
|
|
||||||
result = await session.call_tool(name, arguments)
|
|
||||||
parts = []
|
|
||||||
for content in result.content:
|
|
||||||
if hasattr(content, 'text'):
|
|
||||||
parts.append(content.text)
|
|
||||||
else:
|
|
||||||
parts.append(str(content))
|
|
||||||
return '\n'.join(parts) if parts else ''
|
|
||||||
|
|
||||||
return await _mcp_session(url, headers, _invoke)
|
|
||||||
|
|
||||||
|
|
||||||
async def _connect_all_mcp_servers(servers):
|
|
||||||
"""Connect to all MCP servers concurrently."""
|
|
||||||
results = await asyncio.gather(
|
|
||||||
*(_connect_mcp_server(url, headers) for url, headers in servers),
|
|
||||||
return_exceptions=True
|
|
||||||
)
|
|
||||||
all_defs = []
|
|
||||||
all_executors = {}
|
|
||||||
for (url, _), result in zip(servers, results):
|
|
||||||
if isinstance(result, Exception):
|
|
||||||
logger.exception(f'Failed to connect to MCP server "{url}"', exc_info=result)
|
|
||||||
continue
|
|
||||||
defs, execs = result
|
|
||||||
for td, (fn, ex) in zip(defs, execs.items()):
|
|
||||||
if fn in all_executors:
|
|
||||||
logger.warning(f'MCP tool "{fn}" from {url} conflicts with an already loaded tool. Skipping.')
|
|
||||||
continue
|
|
||||||
all_defs.append(td)
|
|
||||||
all_executors[fn] = ex
|
|
||||||
return all_defs, all_executors
|
|
||||||
|
|
||||||
|
|
||||||
def load_mcp_tools(servers_str):
|
|
||||||
"""
|
|
||||||
Parse MCP servers string and discover tools from each server.
|
|
||||||
Returns (tool_defs, executors) in the same format as load_tools.
|
|
||||||
"""
|
|
||||||
servers = _parse_mcp_servers(servers_str)
|
|
||||||
if not servers:
|
|
||||||
return [], {}
|
|
||||||
|
|
||||||
return asyncio.run(_connect_all_mcp_servers(servers))
|
|
||||||
|
|
||||||
|
|
||||||
def execute_tool(func_name, arguments, executors):
|
|
||||||
"""Execute a tool by function name. Returns result as a JSON string."""
|
|
||||||
fn = executors.get(func_name)
|
|
||||||
if fn is None:
|
|
||||||
return json.dumps({"error": f"Unknown tool: {func_name}"})
|
|
||||||
|
|
||||||
try:
|
|
||||||
if isinstance(arguments, str):
|
|
||||||
arguments = json.loads(arguments)
|
|
||||||
result = fn(arguments)
|
|
||||||
return json.dumps(result) if not isinstance(result, str) else result
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f'Tool "{func_name}" execution failed')
|
|
||||||
return json.dumps({"error": str(e)})
|
|
||||||
|
|
@ -26,7 +26,7 @@ from modules.evaluate import (
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.models import reload_model
|
from modules.models import reload_model
|
||||||
|
|
||||||
PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to", "gradient_checkpointing"]
|
PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to"]
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
|
|
||||||
train_log = {}
|
train_log = {}
|
||||||
|
|
@ -52,7 +52,7 @@ def create_ui():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
|
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
|
||||||
|
|
||||||
with gr.Accordion(label='Target Modules', open=False):
|
with gr.Accordion(label='Target Modules', open=False, elem_classes='tgw-accordion'):
|
||||||
gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM and adapter size.")
|
gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM and adapter size.")
|
||||||
all_linear = gr.Checkbox(label='Target all linear layers', value=True, info='Targets every nn.Linear layer except lm_head. Works for any model architecture. When checked, the individual module checkboxes below are ignored.', elem_classes=['no-background'])
|
all_linear = gr.Checkbox(label='Target all linear layers', value=True, info='Targets every nn.Linear layer except lm_head. Works for any model architecture. When checked, the individual module checkboxes below are ignored.', elem_classes=['no-background'])
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
@ -73,8 +73,8 @@ def create_ui():
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
lora_rank = gr.Slider(label='LoRA Rank', value=8, minimum=0, maximum=2048, step=4, info='Also called dimension count. Use 4–8 for style/format, 128–256 to teach factual knowledge, 1024+ for comprehensive fine-tuning. Very high ranks require significant VRAM.')
|
lora_rank = gr.Slider(label='LoRA Rank', value=8, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
|
||||||
lora_alpha = gr.Slider(label='LoRA Alpha', value=16, minimum=0, maximum=4096, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
lora_alpha = gr.Slider(label='LoRA Alpha', value=16, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
||||||
batch_size = gr.Slider(label='Batch Size', value=32, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
|
batch_size = gr.Slider(label='Batch Size', value=32, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
|
||||||
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
|
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
|
||||||
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=512, step=32, info='Maximum sequence length in tokens. For instruction datasets, conversations longer than this are dropped. For text datasets, documents are split into chunks of this size. Higher values require more VRAM.')
|
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=512, step=32, info='Maximum sequence length in tokens. For instruction datasets, conversations longer than this are dropped. For text datasets, documents are split into chunks of this size. Higher values require more VRAM.')
|
||||||
|
|
@ -87,18 +87,21 @@ def create_ui():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='cosine', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])
|
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='cosine', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])
|
||||||
|
|
||||||
with gr.Accordion(label='Advanced Options', open=False):
|
with gr.Accordion(label='Advanced Options', open=False, elem_classes='tgw-accordion'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Optimizer algorithm. adamw_torch is the standard choice. adamw_bnb_8bit uses less VRAM. adafactor is memory-efficient for large models.', elem_classes=['slim-dropdown'])
|
|
||||||
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate is gradually ramped up from 0 to the target value. This prevents unstable updates early in training.')
|
|
||||||
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.0, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.0, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
||||||
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
|
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
|
||||||
|
with gr.Row():
|
||||||
|
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Optimizer algorithm. adamw_torch is the standard choice. adamw_bnb_8bit uses less VRAM. adafactor is memory-efficient for large models.', elem_classes=['slim-dropdown'])
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gradient_checkpointing = gr.Checkbox(label='Gradient checkpointing', value=True, info='Trades ~20-30% training speed for reduced VRAM usage by recomputing activations during the backward pass instead of storing them. No impact on accuracy.')
|
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate is gradually ramped up from 0 to the target value. This prevents unstable updates early in training.')
|
||||||
|
|
||||||
add_eos_token = gr.Checkbox(label='Add EOS token', value=True, info="Adds EOS token for each document in text datasets.")
|
add_eos_token = gr.Checkbox(label='Add EOS token', value=True, info="Adds EOS token for each document in text datasets.")
|
||||||
excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown'])
|
excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown'])
|
||||||
|
|
||||||
|
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
||||||
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
|
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
|
@ -156,12 +159,12 @@ def create_ui():
|
||||||
refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
|
refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
|
||||||
|
|
||||||
# Training events
|
# Training events
|
||||||
all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to, gradient_checkpointing]
|
all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to]
|
||||||
|
|
||||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||||
start_button.click(do_train, all_params, output)
|
start_button.click(do_train, all_params, output)
|
||||||
stop_button.click(do_interrupt, None, None, queue=False)
|
stop_button.click(do_interrupt, None, None, queue=False)
|
||||||
|
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
|
||||||
|
|
||||||
# Evaluation events. For some reason, the interrupt event
|
# Evaluation events. For some reason, the interrupt event
|
||||||
# doesn't work with the .then() syntax, so I write them one
|
# doesn't work with the .then() syntax, so I write them one
|
||||||
|
|
@ -206,6 +209,10 @@ def do_copy_params(lora_name: str, *args):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def change_rank_limit(use_higher_ranks: bool):
|
||||||
|
mult = 2 if use_higher_ranks else 1
|
||||||
|
return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"}
|
||||||
|
|
||||||
|
|
||||||
def clean_path(base_path: str, path: str):
|
def clean_path(base_path: str, path: str):
|
||||||
"""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
"""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||||
|
|
@ -286,7 +293,7 @@ def calc_trainable_parameters(model):
|
||||||
return trainable_params, all_param
|
return trainable_params, all_param
|
||||||
|
|
||||||
|
|
||||||
def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, excess_length: str, report_to: str, gradient_checkpointing: bool = True):
|
def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, excess_length: str, report_to: str):
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
|
@ -303,11 +310,6 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
||||||
|
|
||||||
# == Input validation / processing ==
|
# == Input validation / processing ==
|
||||||
yield "Preparing the input..."
|
yield "Preparing the input..."
|
||||||
|
|
||||||
if shared.args.loader == 'llama.cpp':
|
|
||||||
yield "Error: LoRA training requires a model loaded with the Transformers loader. GGUF models are not supported for training."
|
|
||||||
return
|
|
||||||
|
|
||||||
lora_file_path = clean_path(None, lora_name)
|
lora_file_path = clean_path(None, lora_name)
|
||||||
if lora_file_path.strip() == '':
|
if lora_file_path.strip() == '':
|
||||||
yield "Missing or invalid LoRA file name input."
|
yield "Missing or invalid LoRA file name input."
|
||||||
|
|
@ -546,8 +548,10 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
||||||
yield f"Failed to load {selected_model}."
|
yield f"Failed to load {selected_model}."
|
||||||
return
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception('Failed to reload the model.')
|
exc = traceback.format_exc()
|
||||||
yield traceback.format_exc().replace('\n', '\n\n')
|
logger.error('Failed to reload the model.')
|
||||||
|
print(exc)
|
||||||
|
yield exc.replace('\n', '\n\n')
|
||||||
return
|
return
|
||||||
|
|
||||||
# == Start prepping the model itself ==
|
# == Start prepping the model itself ==
|
||||||
|
|
@ -699,7 +703,6 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
||||||
load_best_model_at_end=eval_data is not None,
|
load_best_model_at_end=eval_data is not None,
|
||||||
# TODO: Enable multi-device support
|
# TODO: Enable multi-device support
|
||||||
ddp_find_unused_parameters=None,
|
ddp_find_unused_parameters=None,
|
||||||
gradient_checkpointing=gradient_checkpointing,
|
|
||||||
use_cpu=shared.args.cpu,
|
use_cpu=shared.args.cpu,
|
||||||
remove_unused_columns=False,
|
remove_unused_columns=False,
|
||||||
),
|
),
|
||||||
|
|
@ -732,13 +735,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
||||||
if lora_all_param > 0:
|
if lora_all_param > 0:
|
||||||
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
|
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
|
||||||
|
|
||||||
train_log.update({
|
train_log.update({"base_model_name": shared.model_name})
|
||||||
"base_model_name": shared.model_name,
|
train_log.update({"base_model_class": shared.model.__class__.__name__})
|
||||||
"base_model_class": shared.model.__class__.__name__,
|
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
|
||||||
"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False),
|
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
|
||||||
"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False),
|
train_log.update({"projections": projections_string})
|
||||||
"projections": projections_string,
|
|
||||||
})
|
|
||||||
|
|
||||||
if stop_at_loss > 0:
|
if stop_at_loss > 0:
|
||||||
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
|
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
|
||||||
|
|
|
||||||
|
|
@ -44,8 +44,8 @@ class Stream(transformers.StoppingCriteria):
|
||||||
|
|
||||||
|
|
||||||
class LogitsBiasProcessor(LogitsProcessor):
|
class LogitsBiasProcessor(LogitsProcessor):
|
||||||
def __init__(self, logit_bias=None):
|
def __init__(self, logit_bias={}):
|
||||||
self.logit_bias = logit_bias if logit_bias is not None else {}
|
self.logit_bias = logit_bias
|
||||||
if self.logit_bias:
|
if self.logit_bias:
|
||||||
self.keys = list([int(key) for key in self.logit_bias.keys()])
|
self.keys = list([int(key) for key in self.logit_bias.keys()])
|
||||||
values = [self.logit_bias[str(key)] for key in self.keys]
|
values = [self.logit_bias[str(key)] for key in self.keys]
|
||||||
|
|
@ -65,16 +65,14 @@ class LogprobProcessor(LogitsProcessor):
|
||||||
def __init__(self, logprobs=None):
|
def __init__(self, logprobs=None):
|
||||||
self.logprobs = logprobs
|
self.logprobs = logprobs
|
||||||
self.token_alternatives = {}
|
self.token_alternatives = {}
|
||||||
self.token_alternatives_history = []
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
if self.logprobs is not None: # 0-5
|
if self.logprobs is not None: # 0-5
|
||||||
log_e_probabilities = F.log_softmax(logits, dim=1)
|
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs)
|
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||||
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
|
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
|
||||||
top_probs = [float(x) for x in top_values[0]]
|
top_probs = [float(x) for x in top_values[0]]
|
||||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||||
self.token_alternatives_history.append(self.token_alternatives)
|
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
@ -109,6 +107,7 @@ def load_model_HF(model_name):
|
||||||
params = {
|
params = {
|
||||||
'low_cpu_mem_usage': True,
|
'low_cpu_mem_usage': True,
|
||||||
'attn_implementation': shared.args.attn_implementation,
|
'attn_implementation': shared.args.attn_implementation,
|
||||||
|
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
|
||||||
}
|
}
|
||||||
|
|
||||||
if shared.original_args.trust_remote_code:
|
if shared.original_args.trust_remote_code:
|
||||||
|
|
@ -119,17 +118,6 @@ def load_model_HF(model_name):
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.original_args.trust_remote_code)
|
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.original_args.trust_remote_code)
|
||||||
|
|
||||||
# Determine torch_dtype: respect --bf16 flag, otherwise autodetect
|
|
||||||
# from model config, but never allow float32.
|
|
||||||
if shared.args.bf16:
|
|
||||||
params['torch_dtype'] = torch.bfloat16
|
|
||||||
else:
|
|
||||||
dtype = getattr(config, 'torch_dtype', None) or getattr(getattr(config, 'text_config', None), 'torch_dtype', None)
|
|
||||||
if dtype in (torch.float16, torch.bfloat16):
|
|
||||||
params['torch_dtype'] = dtype
|
|
||||||
else:
|
|
||||||
params['torch_dtype'] = torch.float16
|
|
||||||
|
|
||||||
if 'chatglm' in model_name.lower():
|
if 'chatglm' in model_name.lower():
|
||||||
LoaderClass = AutoModel
|
LoaderClass = AutoModel
|
||||||
else:
|
else:
|
||||||
|
|
@ -146,6 +134,8 @@ def load_model_HF(model_name):
|
||||||
shared.args.load_in_4bit,
|
shared.args.load_in_4bit,
|
||||||
shared.args.disk,
|
shared.args.disk,
|
||||||
shared.args.cpu_memory is not None,
|
shared.args.cpu_memory is not None,
|
||||||
|
shared.args.compress_pos_emb > 1,
|
||||||
|
shared.args.alpha_value > 1,
|
||||||
])
|
])
|
||||||
|
|
||||||
# Load the model without any special settings
|
# Load the model without any special settings
|
||||||
|
|
@ -208,6 +198,11 @@ def load_model_HF(model_name):
|
||||||
if shared.args.disk:
|
if shared.args.disk:
|
||||||
params['offload_folder'] = str(Path(shared.args.disk_cache_dir))
|
params['offload_folder'] = str(Path(shared.args.disk_cache_dir))
|
||||||
|
|
||||||
|
if shared.args.compress_pos_emb > 1:
|
||||||
|
params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
|
||||||
|
elif shared.args.alpha_value > 1:
|
||||||
|
params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}
|
||||||
|
|
||||||
logger.info("TRANSFORMERS_PARAMS=")
|
logger.info("TRANSFORMERS_PARAMS=")
|
||||||
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
|
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
|
||||||
print()
|
print()
|
||||||
|
|
|
||||||
101
modules/ui.py
101
modules/ui.py
|
|
@ -66,8 +66,7 @@ theme = gr.themes.Default(
|
||||||
if not shared.args.old_colors:
|
if not shared.args.old_colors:
|
||||||
theme = theme.set(
|
theme = theme.set(
|
||||||
# General Colors
|
# General Colors
|
||||||
border_color_primary='#d2d2d8',
|
border_color_primary='#c5c5d2',
|
||||||
block_border_color='transparent',
|
|
||||||
body_text_color_subdued='#484848',
|
body_text_color_subdued='#484848',
|
||||||
background_fill_secondary='#eaeaea',
|
background_fill_secondary='#eaeaea',
|
||||||
background_fill_secondary_dark='var(--selected-item-color-dark, #282930)',
|
background_fill_secondary_dark='var(--selected-item-color-dark, #282930)',
|
||||||
|
|
@ -75,15 +74,9 @@ if not shared.args.old_colors:
|
||||||
background_fill_primary_dark='var(--darker-gray, #1C1C1D)',
|
background_fill_primary_dark='var(--darker-gray, #1C1C1D)',
|
||||||
body_background_fill="white",
|
body_background_fill="white",
|
||||||
block_background_fill="transparent",
|
block_background_fill="transparent",
|
||||||
body_text_color='#1a1a1a',
|
body_text_color='rgb(64, 64, 64)',
|
||||||
button_secondary_background_fill="white",
|
button_secondary_background_fill="white",
|
||||||
button_secondary_border_color="var(--border-color-primary)",
|
button_secondary_border_color="var(--border-color-primary)",
|
||||||
block_title_text_color='*body_text_color',
|
|
||||||
button_primary_background_fill='#374151',
|
|
||||||
button_primary_background_fill_hover='#4b5563',
|
|
||||||
button_primary_background_fill_hover_dark='rgba(255, 255, 255, 0.05)',
|
|
||||||
button_primary_border_color='#374151',
|
|
||||||
button_primary_text_color='white',
|
|
||||||
input_shadow="none",
|
input_shadow="none",
|
||||||
button_shadow_hover="none",
|
button_shadow_hover="none",
|
||||||
|
|
||||||
|
|
@ -92,11 +85,11 @@ if not shared.args.old_colors:
|
||||||
checkbox_background_color_dark='var(--darker-gray, #1C1C1D)',
|
checkbox_background_color_dark='var(--darker-gray, #1C1C1D)',
|
||||||
block_background_fill_dark='transparent',
|
block_background_fill_dark='transparent',
|
||||||
block_border_color_dark='transparent',
|
block_border_color_dark='transparent',
|
||||||
input_border_color_dark='var(--border-color-dark)',
|
input_border_color_dark='var(--border-color-dark, #525252)',
|
||||||
input_border_color_focus_dark='var(--border-color-dark)',
|
input_border_color_focus_dark='var(--border-color-dark, #525252)',
|
||||||
checkbox_border_color_dark='rgba(255, 255, 255, 0.2)',
|
checkbox_border_color_dark='var(--border-color-dark, #525252)',
|
||||||
border_color_primary_dark='var(--border-color-dark)',
|
border_color_primary_dark='var(--border-color-dark, #525252)',
|
||||||
button_secondary_border_color_dark='var(--border-color-dark)',
|
button_secondary_border_color_dark='var(--border-color-dark, #525252)',
|
||||||
body_background_fill_dark='var(--dark-gray, #212125)',
|
body_background_fill_dark='var(--dark-gray, #212125)',
|
||||||
button_primary_background_fill_dark='transparent',
|
button_primary_background_fill_dark='transparent',
|
||||||
button_secondary_background_fill_dark='transparent',
|
button_secondary_background_fill_dark='transparent',
|
||||||
|
|
@ -114,12 +107,10 @@ if not shared.args.old_colors:
|
||||||
block_shadow_dark='none',
|
block_shadow_dark='none',
|
||||||
input_shadow_focus='none',
|
input_shadow_focus='none',
|
||||||
input_shadow_focus_dark='none',
|
input_shadow_focus_dark='none',
|
||||||
button_large_radius='0.75rem',
|
button_large_radius='0.375rem',
|
||||||
button_small_radius='0.75rem',
|
|
||||||
button_large_padding='6px 12px',
|
button_large_padding='6px 12px',
|
||||||
input_radius='0.5rem',
|
input_radius='0.375rem',
|
||||||
block_radius='0.375rem',
|
block_radius='0',
|
||||||
button_transition='background-color 0.15s ease, border-color 0.15s ease, color 0.15s ease',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (shared.user_data_dir / "notification.mp3").exists():
|
if (shared.user_data_dir / "notification.mp3").exists():
|
||||||
|
|
@ -129,8 +120,58 @@ else:
|
||||||
|
|
||||||
|
|
||||||
def list_model_elements():
|
def list_model_elements():
|
||||||
from modules.loaders import list_model_elements
|
elements = [
|
||||||
return list_model_elements()
|
'filter_by_loader',
|
||||||
|
'loader',
|
||||||
|
'cpu_memory',
|
||||||
|
'gpu_layers',
|
||||||
|
'fit_target',
|
||||||
|
'cpu_moe',
|
||||||
|
'threads',
|
||||||
|
'threads_batch',
|
||||||
|
'batch_size',
|
||||||
|
'ubatch_size',
|
||||||
|
'ctx_size',
|
||||||
|
'cache_type',
|
||||||
|
'tensor_split',
|
||||||
|
'extra_flags',
|
||||||
|
'streaming_llm',
|
||||||
|
'gpu_split',
|
||||||
|
'alpha_value',
|
||||||
|
'rope_freq_base',
|
||||||
|
'compress_pos_emb',
|
||||||
|
'compute_dtype',
|
||||||
|
'quant_type',
|
||||||
|
'load_in_8bit',
|
||||||
|
'load_in_4bit',
|
||||||
|
'attn_implementation',
|
||||||
|
'cpu',
|
||||||
|
'disk',
|
||||||
|
'row_split',
|
||||||
|
'no_kv_offload',
|
||||||
|
'no_mmap',
|
||||||
|
'mlock',
|
||||||
|
'numa',
|
||||||
|
'parallel',
|
||||||
|
'use_double_quant',
|
||||||
|
'bf16',
|
||||||
|
'enable_tp',
|
||||||
|
'tp_backend',
|
||||||
|
'cfg_cache',
|
||||||
|
'no_use_fast',
|
||||||
|
'model_draft',
|
||||||
|
'draft_max',
|
||||||
|
'gpu_layers_draft',
|
||||||
|
'device_draft',
|
||||||
|
'ctx_size_draft',
|
||||||
|
'spec_type',
|
||||||
|
'spec_ngram_size_n',
|
||||||
|
'spec_ngram_size_m',
|
||||||
|
'spec_ngram_min_hits',
|
||||||
|
'mmproj',
|
||||||
|
]
|
||||||
|
|
||||||
|
return elements
|
||||||
|
|
||||||
|
|
||||||
def list_interface_input_elements():
|
def list_interface_input_elements():
|
||||||
|
|
@ -208,8 +249,6 @@ def list_interface_input_elements():
|
||||||
'unique_id',
|
'unique_id',
|
||||||
'textbox',
|
'textbox',
|
||||||
'start_with',
|
'start_with',
|
||||||
'selected_tools',
|
|
||||||
'mcp_servers',
|
|
||||||
'mode',
|
'mode',
|
||||||
'chat_style',
|
'chat_style',
|
||||||
'chat-instruct_command',
|
'chat-instruct_command',
|
||||||
|
|
@ -301,7 +340,7 @@ def apply_interface_values(state, use_persistent=False):
|
||||||
|
|
||||||
elements = list_interface_input_elements()
|
elements = list_interface_input_elements()
|
||||||
|
|
||||||
if not state:
|
if len(state) == 0:
|
||||||
return [gr.update() for k in elements] # Dummy, do nothing
|
return [gr.update() for k in elements] # Dummy, do nothing
|
||||||
else:
|
else:
|
||||||
return [state[k] if k in state else gr.update() for k in elements]
|
return [state[k] if k in state else gr.update() for k in elements]
|
||||||
|
|
@ -309,22 +348,19 @@ def apply_interface_values(state, use_persistent=False):
|
||||||
|
|
||||||
def save_settings(state, preset, extensions_list, show_controls, theme_state, manual_save=False):
|
def save_settings(state, preset, extensions_list, show_controls, theme_state, manual_save=False):
|
||||||
output = copy.deepcopy(shared.settings)
|
output = copy.deepcopy(shared.settings)
|
||||||
|
exclude = []
|
||||||
for k in state:
|
for k in state:
|
||||||
if k in shared.settings:
|
if k in shared.settings and k not in exclude:
|
||||||
output[k] = state[k]
|
output[k] = state[k]
|
||||||
|
|
||||||
if preset:
|
|
||||||
output['preset'] = preset
|
output['preset'] = preset
|
||||||
output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook']
|
output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook']
|
||||||
if state.get('character_menu'):
|
|
||||||
output['character'] = state['character_menu']
|
output['character'] = state['character_menu']
|
||||||
if state.get('user_menu'):
|
if 'user_menu' in state and state['user_menu']:
|
||||||
output['user'] = state['user_menu']
|
output['user'] = state['user_menu']
|
||||||
output['seed'] = int(output['seed'])
|
output['seed'] = int(output['seed'])
|
||||||
output['custom_stopping_strings'] = output.get('custom_stopping_strings') or ''
|
|
||||||
output['custom_token_bans'] = output.get('custom_token_bans') or ''
|
|
||||||
output['show_controls'] = show_controls
|
output['show_controls'] = show_controls
|
||||||
output['dark_theme'] = theme_state == 'dark'
|
output['dark_theme'] = True if theme_state == 'dark' else False
|
||||||
output.pop('instruction_template_str')
|
output.pop('instruction_template_str')
|
||||||
output.pop('truncation_length')
|
output.pop('truncation_length')
|
||||||
|
|
||||||
|
|
@ -434,8 +470,6 @@ def setup_auto_save():
|
||||||
'user_bio',
|
'user_bio',
|
||||||
'custom_system_message',
|
'custom_system_message',
|
||||||
'chat_template_str',
|
'chat_template_str',
|
||||||
'selected_tools',
|
|
||||||
'mcp_servers',
|
|
||||||
|
|
||||||
# Parameters tab (ui_parameters.py) - Generation parameters
|
# Parameters tab (ui_parameters.py) - Generation parameters
|
||||||
'preset_menu',
|
'preset_menu',
|
||||||
|
|
@ -486,6 +520,7 @@ def setup_auto_save():
|
||||||
'skip_special_tokens',
|
'skip_special_tokens',
|
||||||
'stream',
|
'stream',
|
||||||
'static_cache',
|
'static_cache',
|
||||||
|
'truncation_length',
|
||||||
'seed',
|
'seed',
|
||||||
'sampler_priority',
|
'sampler_priority',
|
||||||
'custom_stopping_strings',
|
'custom_stopping_strings',
|
||||||
|
|
|
||||||
|
|
@ -28,8 +28,7 @@ def create_ui():
|
||||||
shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu)
|
shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu)
|
||||||
shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu)
|
shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu)
|
||||||
shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat')
|
shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat')
|
||||||
shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'], elem_id='new-chat-btn')
|
shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'])
|
||||||
shared.gradio['Start incognito chat'] = gr.Button('Incognito chat', visible=False, elem_id='incognito-chat-btn')
|
|
||||||
shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True)
|
shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True)
|
||||||
|
|
||||||
shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat')
|
shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat')
|
||||||
|
|
@ -52,7 +51,7 @@ def create_ui():
|
||||||
shared.gradio['html_display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': [], 'metadata': {}}, '', '', 'chat', 'cai-chat', '')['html'], visible=True)
|
shared.gradio['html_display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': [], 'metadata': {}}, '', '', 'chat', 'cai-chat', '')['html'], visible=True)
|
||||||
with gr.Row(elem_id="chat-input-row"):
|
with gr.Row(elem_id="chat-input-row"):
|
||||||
with gr.Column(scale=1, elem_id='gr-hover-container'):
|
with gr.Column(scale=1, elem_id='gr-hover-container'):
|
||||||
gr.HTML(value='<div class="hover-element" onclick="void(0)"><span id="hover-element-button"><svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><line x1="4" y1="6" x2="20" y2="6"></line><line x1="4" y1="12" x2="20" y2="12"></line><line x1="4" y1="18" x2="20" y2="18"></line></svg></span><div class="hover-menu" id="hover-menu"></div></div>', elem_id='gr-hover')
|
gr.HTML(value='<div class="hover-element" onclick="void(0)"><span style="width: 100px; display: block" id="hover-element-button">☰</span><div class="hover-menu" id="hover-menu"></div>', elem_id='gr-hover')
|
||||||
|
|
||||||
with gr.Column(scale=10, elem_id='chat-input-container'):
|
with gr.Column(scale=10, elem_id='chat-input-container'):
|
||||||
shared.gradio['textbox'] = gr.MultimodalTextbox(label='', placeholder='Send a message', file_types=['text', '.pdf', 'image'], file_count="multiple", elem_id='chat-input', elem_classes=['add_scrollbar'])
|
shared.gradio['textbox'] = gr.MultimodalTextbox(label='', placeholder='Send a message', file_types=['text', '.pdf', 'image'], file_count="multiple", elem_id='chat-input', elem_classes=['add_scrollbar'])
|
||||||
|
|
@ -82,7 +81,7 @@ def create_ui():
|
||||||
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
||||||
|
|
||||||
shared.gradio['reasoning_effort'] = gr.Dropdown(value=shared.settings['reasoning_effort'], choices=['low', 'medium', 'high'], label='Reasoning effort', info='Used by GPT-OSS.')
|
shared.gradio['reasoning_effort'] = gr.Dropdown(value=shared.settings['reasoning_effort'], choices=['low', 'medium', 'high'], label='Reasoning effort', info='Used by GPT-OSS.')
|
||||||
shared.gradio['enable_thinking'] = gr.Checkbox(value=shared.settings['enable_thinking'], label='Enable thinking', info='For models with thinking support.')
|
shared.gradio['enable_thinking'] = gr.Checkbox(value=shared.settings['enable_thinking'], label='Enable thinking', info='Used by Seed-OSS and pre-2507 Qwen3.')
|
||||||
|
|
||||||
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
||||||
|
|
||||||
|
|
@ -92,24 +91,6 @@ def create_ui():
|
||||||
|
|
||||||
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
||||||
|
|
||||||
from modules.tool_use import get_available_tools
|
|
||||||
shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=shared.settings.get('selected_tools', []), label='Tools', info='Functions the model can call during generation.', elem_id='tools-group')
|
|
||||||
shared.gradio['tools_refresh'] = gr.Button('Refresh list', elem_id='tools-refresh-btn', visible=False)
|
|
||||||
shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']])
|
|
||||||
|
|
||||||
def sync_web_tools(selected):
|
|
||||||
if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools():
|
|
||||||
selected.append('fetch_webpage')
|
|
||||||
|
|
||||||
return gr.update(value=selected)
|
|
||||||
|
|
||||||
shared.gradio['selected_tools'].change(fn=sync_web_tools, inputs=[shared.gradio['selected_tools']], outputs=[shared.gradio['selected_tools']], show_progress=False)
|
|
||||||
|
|
||||||
with gr.Accordion('MCP servers', open=False):
|
|
||||||
shared.gradio['mcp_servers'] = gr.Textbox(value=shared.settings.get('mcp_servers', ''), lines=3, max_lines=3, label='', info='One url per line. For headers, write url,Header: value,Header2: value2', elem_classes=['add_scrollbar'])
|
|
||||||
|
|
||||||
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')
|
shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')
|
||||||
|
|
||||||
|
|
@ -294,10 +275,6 @@ def create_event_handlers():
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['Start incognito chat'].click(
|
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
|
||||||
chat.handle_start_incognito_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
|
||||||
|
|
||||||
shared.gradio['delete_chat-confirm'].click(
|
shared.gradio['delete_chat-confirm'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
||||||
|
|
@ -353,13 +330,13 @@ def create_event_handlers():
|
||||||
shared.gradio['load_template'].click(chat.handle_load_template_click, gradio('instruction_template'), gradio('instruction_template_str', 'instruction_template'), show_progress=False)
|
shared.gradio['load_template'].click(chat.handle_load_template_click, gradio('instruction_template'), gradio('instruction_template_str', 'instruction_template'), show_progress=False)
|
||||||
shared.gradio['save_template'].click(
|
shared.gradio['save_template'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
chat.handle_save_template_click, gradio('instruction_template_str'), gradio('save_filename', 'save_root', 'save_contents', 'save_root_state', 'file_saver'), show_progress=False)
|
chat.handle_save_template_click, gradio('instruction_template_str'), gradio('save_filename', 'save_root', 'save_contents', 'file_saver'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['restore_character'].click(
|
shared.gradio['restore_character'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
chat.restore_character_for_ui, gradio('interface_state'), gradio('interface_state', 'name2', 'context', 'greeting', 'character_picture'), show_progress=False)
|
chat.restore_character_for_ui, gradio('interface_state'), gradio('interface_state', 'name2', 'context', 'greeting', 'character_picture'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['delete_template'].click(chat.handle_delete_template_click, gradio('instruction_template'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False)
|
shared.gradio['delete_template'].click(chat.handle_delete_template_click, gradio('instruction_template'), gradio('delete_filename', 'delete_root', 'file_deleter'), show_progress=False)
|
||||||
shared.gradio['save_chat_history'].click(
|
shared.gradio['save_chat_history'].click(
|
||||||
lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then(
|
lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then(
|
||||||
None, gradio('temporary_text', 'character_menu', 'mode'), None, js=f'(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}')
|
None, gradio('temporary_text', 'character_menu', 'mode'), None, js=f'(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}')
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from modules.text_generation import (
|
||||||
stop_everything_event
|
stop_everything_event
|
||||||
)
|
)
|
||||||
from modules.ui_notebook import store_notebook_state_and_debounce
|
from modules.ui_notebook import store_notebook_state_and_debounce
|
||||||
from modules.utils import gradio, sanitize_filename
|
from modules.utils import gradio
|
||||||
|
|
||||||
inputs = ('textbox-default', 'interface_state')
|
inputs = ('textbox-default', 'interface_state')
|
||||||
outputs = ('output_textbox', 'html-default')
|
outputs = ('output_textbox', 'html-default')
|
||||||
|
|
@ -167,7 +167,6 @@ def handle_new_prompt():
|
||||||
|
|
||||||
|
|
||||||
def handle_delete_prompt_confirm_default(prompt_name):
|
def handle_delete_prompt_confirm_default(prompt_name):
|
||||||
prompt_name = sanitize_filename(prompt_name)
|
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
||||||
|
|
||||||
|
|
@ -200,8 +199,6 @@ def handle_rename_prompt_click_default(current_name):
|
||||||
|
|
||||||
|
|
||||||
def handle_rename_prompt_confirm_default(new_name, current_name):
|
def handle_rename_prompt_confirm_default(new_name, current_name):
|
||||||
new_name = sanitize_filename(new_name)
|
|
||||||
current_name = sanitize_filename(current_name)
|
|
||||||
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt"
|
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt"
|
||||||
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,14 @@
|
||||||
|
import traceback
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import chat, presets, shared, ui, utils
|
from modules import chat, presets, shared, ui, utils
|
||||||
from modules.logging_colors import logger
|
|
||||||
from modules.utils import gradio, sanitize_filename
|
from modules.utils import gradio, sanitize_filename
|
||||||
|
|
||||||
|
|
||||||
def create_ui():
|
def create_ui():
|
||||||
mu = shared.args.multi_user
|
mu = shared.args.multi_user
|
||||||
|
|
||||||
# Server-side per-session root paths for the generic file saver/deleter.
|
|
||||||
# Set by the handler that opens the dialog, read by the confirm handler.
|
|
||||||
# Using gr.State so they are session-scoped and safe for multi-user.
|
|
||||||
shared.gradio['save_root_state'] = gr.State(None)
|
|
||||||
shared.gradio['delete_root_state'] = gr.State(None)
|
|
||||||
|
|
||||||
# Text file saver
|
# Text file saver
|
||||||
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']:
|
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']:
|
||||||
shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name')
|
shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name')
|
||||||
|
|
@ -71,13 +66,13 @@ def create_event_handlers():
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
handle_save_preset_click, gradio('interface_state'), gradio('save_preset_contents', 'save_preset_filename', 'preset_saver'), show_progress=False)
|
handle_save_preset_click, gradio('interface_state'), gradio('save_preset_contents', 'save_preset_filename', 'preset_saver'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['delete_preset'].click(handle_delete_preset_click, gradio('preset_menu'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False)
|
shared.gradio['delete_preset'].click(handle_delete_preset_click, gradio('preset_menu'), gradio('delete_filename', 'delete_root', 'file_deleter'), show_progress=False)
|
||||||
shared.gradio['save_grammar'].click(handle_save_grammar_click, gradio('grammar_string'), gradio('save_contents', 'save_filename', 'save_root', 'save_root_state', 'file_saver'), show_progress=False)
|
shared.gradio['save_grammar'].click(handle_save_grammar_click, gradio('grammar_string'), gradio('save_contents', 'save_filename', 'save_root', 'file_saver'), show_progress=False)
|
||||||
shared.gradio['delete_grammar'].click(handle_delete_grammar_click, gradio('grammar_file'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False)
|
shared.gradio['delete_grammar'].click(handle_delete_grammar_click, gradio('grammar_file'), gradio('delete_filename', 'delete_root', 'file_deleter'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['save_preset_confirm'].click(handle_save_preset_confirm_click, gradio('save_preset_filename', 'save_preset_contents'), gradio('preset_menu', 'preset_saver'), show_progress=False)
|
shared.gradio['save_preset_confirm'].click(handle_save_preset_confirm_click, gradio('save_preset_filename', 'save_preset_contents'), gradio('preset_menu', 'preset_saver'), show_progress=False)
|
||||||
shared.gradio['save_confirm'].click(handle_save_confirm_click, gradio('save_root_state', 'save_filename', 'save_contents'), gradio('save_root_state', 'file_saver'), show_progress=False)
|
shared.gradio['save_confirm'].click(handle_save_confirm_click, gradio('save_root', 'save_filename', 'save_contents'), gradio('file_saver'), show_progress=False)
|
||||||
shared.gradio['delete_confirm'].click(handle_delete_confirm_click, gradio('delete_root_state', 'delete_filename'), gradio('delete_root_state', 'file_deleter'), show_progress=False)
|
shared.gradio['delete_confirm'].click(handle_delete_confirm_click, gradio('delete_root', 'delete_filename'), gradio('file_deleter'), show_progress=False)
|
||||||
shared.gradio['save_character_confirm'].click(handle_save_character_confirm_click, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), gradio('character_menu', 'character_saver'), show_progress=False)
|
shared.gradio['save_character_confirm'].click(handle_save_character_confirm_click, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), gradio('character_menu', 'character_saver'), show_progress=False)
|
||||||
shared.gradio['delete_character_confirm'].click(handle_delete_character_confirm_click, gradio('character_menu'), gradio('character_menu', 'character_deleter'), show_progress=False)
|
shared.gradio['delete_character_confirm'].click(handle_delete_character_confirm_click, gradio('character_menu'), gradio('character_menu', 'character_deleter'), show_progress=False)
|
||||||
|
|
||||||
|
|
@ -102,7 +97,7 @@ def handle_save_preset_confirm_click(filename, contents):
|
||||||
output = gr.update(choices=available_presets, value=filename)
|
output = gr.update(choices=available_presets, value=filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
output = gr.update()
|
output = gr.update()
|
||||||
logger.exception("Failed to save preset")
|
traceback.print_exc()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
output,
|
output,
|
||||||
|
|
@ -110,30 +105,24 @@ def handle_save_preset_confirm_click(filename, contents):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def handle_save_confirm_click(root_state, filename, contents):
|
def handle_save_confirm_click(root, filename, contents):
|
||||||
try:
|
try:
|
||||||
if root_state is None:
|
|
||||||
return None, gr.update(visible=False)
|
|
||||||
|
|
||||||
filename = sanitize_filename(filename)
|
filename = sanitize_filename(filename)
|
||||||
utils.save_file(root_state + filename, contents)
|
utils.save_file(root + filename, contents)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to save file")
|
traceback.print_exc()
|
||||||
|
|
||||||
return None, gr.update(visible=False)
|
return gr.update(visible=False)
|
||||||
|
|
||||||
|
|
||||||
def handle_delete_confirm_click(root_state, filename):
|
def handle_delete_confirm_click(root, filename):
|
||||||
try:
|
try:
|
||||||
if root_state is None:
|
|
||||||
return None, gr.update(visible=False)
|
|
||||||
|
|
||||||
filename = sanitize_filename(filename)
|
filename = sanitize_filename(filename)
|
||||||
utils.delete_file(root_state + filename)
|
utils.delete_file(root + filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to delete file")
|
traceback.print_exc()
|
||||||
|
|
||||||
return None, gr.update(visible=False)
|
return gr.update(visible=False)
|
||||||
|
|
||||||
|
|
||||||
def handle_save_character_confirm_click(name2, greeting, context, character_picture, filename):
|
def handle_save_character_confirm_click(name2, greeting, context, character_picture, filename):
|
||||||
|
|
@ -143,7 +132,7 @@ def handle_save_character_confirm_click(name2, greeting, context, character_pict
|
||||||
output = gr.update(choices=available_characters, value=filename)
|
output = gr.update(choices=available_characters, value=filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
output = gr.update()
|
output = gr.update()
|
||||||
logger.exception("Failed to save character")
|
traceback.print_exc()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
output,
|
output,
|
||||||
|
|
@ -158,7 +147,7 @@ def handle_delete_character_confirm_click(character):
|
||||||
output = chat.update_character_menu_after_deletion(index)
|
output = chat.update_character_menu_after_deletion(index)
|
||||||
except Exception:
|
except Exception:
|
||||||
output = gr.update()
|
output = gr.update()
|
||||||
logger.exception("Failed to delete character")
|
traceback.print_exc()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
output,
|
output,
|
||||||
|
|
@ -176,32 +165,26 @@ def handle_save_preset_click(state):
|
||||||
|
|
||||||
|
|
||||||
def handle_delete_preset_click(preset):
|
def handle_delete_preset_click(preset):
|
||||||
root = str(shared.user_data_dir / "presets") + "/"
|
|
||||||
return [
|
return [
|
||||||
f"{preset}.yaml",
|
f"{preset}.yaml",
|
||||||
root,
|
str(shared.user_data_dir / "presets") + "/",
|
||||||
root,
|
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def handle_save_grammar_click(grammar_string):
|
def handle_save_grammar_click(grammar_string):
|
||||||
root = str(shared.user_data_dir / "grammars") + "/"
|
|
||||||
return [
|
return [
|
||||||
grammar_string,
|
grammar_string,
|
||||||
"My Fancy Grammar.gbnf",
|
"My Fancy Grammar.gbnf",
|
||||||
root,
|
str(shared.user_data_dir / "grammars") + "/",
|
||||||
root,
|
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def handle_delete_grammar_click(grammar_file):
|
def handle_delete_grammar_click(grammar_file):
|
||||||
root = str(shared.user_data_dir / "grammars") + "/"
|
|
||||||
return [
|
return [
|
||||||
grammar_file,
|
grammar_file,
|
||||||
root,
|
str(shared.user_data_dir / "grammars") + "/",
|
||||||
root,
|
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -213,7 +196,7 @@ def handle_save_user_confirm_click(name1, user_bio, your_picture, filename):
|
||||||
output = gr.update(choices=available_users, value=filename)
|
output = gr.update(choices=available_users, value=filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
output = gr.update()
|
output = gr.update()
|
||||||
logger.exception("Failed to save user")
|
traceback.print_exc()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
output,
|
output,
|
||||||
|
|
@ -228,7 +211,7 @@ def handle_delete_user_confirm_click(user):
|
||||||
output = chat.update_user_menu_after_deletion(index)
|
output = chat.update_user_menu_after_deletion(index)
|
||||||
except Exception:
|
except Exception:
|
||||||
output = gr.update()
|
output = gr.update()
|
||||||
logger.exception("Failed to delete user")
|
traceback.print_exc()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
output,
|
output,
|
||||||
|
|
|
||||||
|
|
@ -728,8 +728,6 @@ def generate_prompt_variation(state):
|
||||||
variation = variation.rsplit("</think>", 1)[1]
|
variation = variation.rsplit("</think>", 1)[1]
|
||||||
elif "<|start|>assistant<|channel|>final<|message|>" in variation:
|
elif "<|start|>assistant<|channel|>final<|message|>" in variation:
|
||||||
variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
|
variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
|
||||||
elif "<|channel|>final<|message|>" in variation:
|
|
||||||
variation = variation.rsplit("<|channel|>final<|message|>", 1)[1]
|
|
||||||
elif "</seed:think>" in variation:
|
elif "</seed:think>" in variation:
|
||||||
variation = variation.rsplit("</seed:think>", 1)[1]
|
variation = variation.rsplit("</seed:think>", 1)[1]
|
||||||
|
|
||||||
|
|
@ -798,9 +796,6 @@ def generate(state, save_images=True):
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
seed = random.randint(0, 2**32 - 1)
|
seed = random.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
# Store resolved seed back so callers (e.g. API) can access it
|
|
||||||
state['image_seed_resolved'] = seed
|
|
||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
if device is None:
|
if device is None:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
@ -919,8 +914,9 @@ def generate(state, save_images=True):
|
||||||
yield all_images, progress_bar_html()
|
yield all_images, progress_bar_html()
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
|
|
||||||
except Exception:
|
except Exception as e:
|
||||||
logger.exception("Image generation failed")
|
logger.error(f"Image generation failed: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
yield [], progress_bar_html()
|
yield [], progress_bar_html()
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,18 +42,16 @@ def create_ui():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=-1, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Number of layers to offload to the GPU. -1 = auto.')
|
shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=-1, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Number of layers to offload to the GPU. -1 = auto.')
|
||||||
shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=1048576, step=1024, value=shared.args.ctx_size, info='Context length. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
|
shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. llama.cpp: 0 = auto if gpu-layers is also -1. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
|
||||||
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
|
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
|
||||||
shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
|
shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
|
||||||
shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
|
shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
|
||||||
shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices.')
|
shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices. Default: 1024.')
|
||||||
shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
|
shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['vram_info'] = gr.HTML(value=get_initial_vram_info())
|
shared.gradio['vram_info'] = gr.HTML(value=get_initial_vram_info())
|
||||||
if not shared.args.portable:
|
shared.gradio['cpu_moe'] = gr.Checkbox(label="cpu-moe", value=shared.args.cpu_moe, info='Move the experts to the CPU. Saves VRAM on MoE models.')
|
||||||
shared.gradio['ik'] = gr.Checkbox(label="ik", value=shared.args.ik, info='Use ik_llama.cpp instead of upstream llama.cpp.')
|
|
||||||
|
|
||||||
shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming-llm", value=shared.args.streaming_llm, info='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
|
shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming-llm", value=shared.args.streaming_llm, info='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
|
||||||
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
|
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
|
||||||
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
|
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
|
||||||
|
|
@ -66,13 +64,13 @@ def create_ui():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Multimodal
|
# Multimodal
|
||||||
with gr.Accordion("Multimodal (vision)", open=False) as shared.gradio['mmproj_accordion']:
|
with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info=f'Select a file that matches your model. Must be placed in {shared.user_data_dir}/mmproj/', interactive=not mu)
|
shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info=f'Select a file that matches your model. Must be placed in {shared.user_data_dir}/mmproj/', interactive=not mu)
|
||||||
ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
with gr.Accordion("Speculative decoding", open=False) as shared.gradio['speculative_decoding_accordion']:
|
with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
|
||||||
shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Maximum number of tokens to draft for speculative decoding. Recommended: 4 for draft model, 64 for n-gram.')
|
shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Maximum number of tokens to draft for speculative decoding. Recommended: 4 for draft model, 64 for n-gram.')
|
||||||
|
|
||||||
gr.Markdown('#### Draft model')
|
gr.Markdown('#### Draft model')
|
||||||
|
|
@ -91,7 +89,7 @@ def create_ui():
|
||||||
shared.gradio['spec_ngram_min_hits'] = gr.Number(label="spec-ngram-min-hits", precision=0, step=1, value=shared.args.spec_ngram_min_hits, info='Minimum n-gram hits for ngram-map speculative decoding.', visible=shared.args.spec_type != 'none')
|
shared.gradio['spec_ngram_min_hits'] = gr.Number(label="spec-ngram-min-hits", precision=0, step=1, value=shared.args.spec_ngram_min_hits, info='Minimum n-gram hits for ngram-map speculative decoding.', visible=shared.args.spec_type != 'none')
|
||||||
|
|
||||||
gr.Markdown("## Other options")
|
gr.Markdown("## Other options")
|
||||||
with gr.Accordion("See more options", open=False):
|
with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['parallel'] = gr.Slider(label="parallel", minimum=1, step=1, maximum=64, value=shared.args.parallel, info='Number of parallel request slots for the API. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
|
shared.gradio['parallel'] = gr.Slider(label="parallel", minimum=1, step=1, maximum=64, value=shared.args.parallel, info='Number of parallel request slots for the API. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
|
||||||
|
|
@ -100,17 +98,19 @@ def create_ui():
|
||||||
shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
|
shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
|
||||||
shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
|
shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
|
||||||
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
|
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
|
||||||
shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Extra flags to pass to llama-server. Example: --jinja --rpc 192.168.1.100:50052', value=shared.args.extra_flags)
|
shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
|
||||||
shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
|
shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
|
||||||
|
shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
|
||||||
|
shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
|
||||||
|
shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
|
||||||
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
|
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
|
||||||
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
|
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.')
|
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.')
|
||||||
shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
|
shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
|
||||||
shared.gradio['cpu_moe'] = gr.Checkbox(label="cpu-moe", value=shared.args.cpu_moe, info='Move the experts to the CPU. Saves VRAM on MoE models.')
|
|
||||||
shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
||||||
shared.gradio['no_kv_offload'] = gr.Checkbox(label="no_kv_offload", value=shared.args.no_kv_offload, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces performance.')
|
shared.gradio['no_kv_offload'] = gr.Checkbox(label="no_kv_offload", value=shared.args.no_kv_offload, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
|
||||||
shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
|
shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
|
||||||
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
|
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
|
||||||
shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
|
shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
|
||||||
|
|
@ -137,7 +137,7 @@ def create_ui():
|
||||||
ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu)
|
shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu)
|
||||||
gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenever the model gets loaded, this template will be used in place of the template specified in the model's metadata, which sometimes is wrong.")
|
gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenever the model gets loaded, this template will be used in place of the template specified in the model's medatada, which sometimes is wrong.")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
|
shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
|
||||||
|
|
@ -225,14 +225,16 @@ def load_model_wrapper(selected_model, loader, autoload=False):
|
||||||
else:
|
else:
|
||||||
yield f"Failed to load `{selected_model}`."
|
yield f"Failed to load `{selected_model}`."
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception('Failed to load the model.')
|
exc = traceback.format_exc()
|
||||||
yield traceback.format_exc().replace('\n', '\n\n')
|
logger.error('Failed to load the model.')
|
||||||
|
print(exc)
|
||||||
|
yield exc.replace('\n', '\n\n')
|
||||||
|
|
||||||
|
|
||||||
def load_lora_wrapper(selected_loras):
|
def load_lora_wrapper(selected_loras):
|
||||||
yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
|
yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
|
||||||
add_lora_to_model(selected_loras)
|
add_lora_to_model(selected_loras)
|
||||||
yield ("Successfully applied the LoRAs")
|
yield ("Successfuly applied the LoRAs")
|
||||||
|
|
||||||
|
|
||||||
def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False):
|
def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False):
|
||||||
|
|
@ -386,12 +388,8 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
|
||||||
def update_truncation_length(current_length, state):
|
def update_truncation_length(current_length, state):
|
||||||
if 'loader' in state:
|
if 'loader' in state:
|
||||||
if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
|
if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
|
||||||
if state['ctx_size'] > 0:
|
|
||||||
return state['ctx_size']
|
return state['ctx_size']
|
||||||
|
|
||||||
# ctx_size == 0 means auto: use the actual value from the server
|
|
||||||
return shared.settings['truncation_length']
|
|
||||||
|
|
||||||
return current_length
|
return current_length
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from modules.text_generation import (
|
||||||
get_token_ids,
|
get_token_ids,
|
||||||
stop_everything_event
|
stop_everything_event
|
||||||
)
|
)
|
||||||
from modules.utils import gradio, sanitize_filename
|
from modules.utils import gradio
|
||||||
|
|
||||||
_notebook_file_lock = threading.Lock()
|
_notebook_file_lock = threading.Lock()
|
||||||
_notebook_auto_save_timer = None
|
_notebook_auto_save_timer = None
|
||||||
|
|
@ -202,7 +202,6 @@ def handle_new_prompt():
|
||||||
|
|
||||||
|
|
||||||
def handle_delete_prompt_confirm_notebook(prompt_name):
|
def handle_delete_prompt_confirm_notebook(prompt_name):
|
||||||
prompt_name = sanitize_filename(prompt_name)
|
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
||||||
|
|
||||||
|
|
@ -234,8 +233,6 @@ def handle_rename_prompt_click_notebook(current_name):
|
||||||
|
|
||||||
|
|
||||||
def handle_rename_prompt_confirm_notebook(new_name, current_name):
|
def handle_rename_prompt_confirm_notebook(new_name, current_name):
|
||||||
new_name = sanitize_filename(new_name)
|
|
||||||
current_name = sanitize_filename(current_name)
|
|
||||||
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt"
|
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt"
|
||||||
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
|
|
||||||
|
|
@ -252,7 +249,6 @@ def handle_rename_prompt_confirm_notebook(new_name, current_name):
|
||||||
|
|
||||||
def autosave_prompt(text, prompt_name):
|
def autosave_prompt(text, prompt_name):
|
||||||
"""Automatically save the text to the selected prompt file"""
|
"""Automatically save the text to the selected prompt file"""
|
||||||
prompt_name = sanitize_filename(prompt_name)
|
|
||||||
if prompt_name and text.strip():
|
if prompt_name and text.strip():
|
||||||
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt"
|
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt"
|
||||||
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
|
||||||
|
|
@ -37,10 +37,10 @@ def create_ui():
|
||||||
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=shared.settings['dynamic_temperature'], label='dynamic_temperature')
|
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=shared.settings['dynamic_temperature'], label='dynamic_temperature')
|
||||||
|
|
||||||
gr.Markdown('## Curve cutoff')
|
gr.Markdown('## Curve cutoff')
|
||||||
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=shared.settings['top_p'], step=0.01, label='top_p')
|
|
||||||
shared.gradio['top_k'] = gr.Slider(0, 200, value=shared.settings['top_k'], step=1, label='top_k')
|
|
||||||
shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=shared.settings['min_p'], step=0.01, label='min_p')
|
shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=shared.settings['min_p'], step=0.01, label='min_p')
|
||||||
shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=shared.settings['top_n_sigma'], step=0.01, label='top_n_sigma')
|
shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=shared.settings['top_n_sigma'], step=0.01, label='top_n_sigma')
|
||||||
|
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=shared.settings['top_p'], step=0.01, label='top_p')
|
||||||
|
shared.gradio['top_k'] = gr.Slider(0, 200, value=shared.settings['top_k'], step=1, label='top_k')
|
||||||
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=shared.settings['typical_p'], step=0.01, label='typical_p')
|
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=shared.settings['typical_p'], step=0.01, label='typical_p')
|
||||||
shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=shared.settings['xtc_threshold'], step=0.01, label='xtc_threshold', info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.')
|
shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=shared.settings['xtc_threshold'], step=0.01, label='xtc_threshold', info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.')
|
||||||
shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=shared.settings['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.')
|
shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=shared.settings['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.')
|
||||||
|
|
@ -73,7 +73,7 @@ def create_ui():
|
||||||
gr.Markdown('## Other options')
|
gr.Markdown('## Other options')
|
||||||
shared.gradio['do_sample'] = gr.Checkbox(value=shared.settings['do_sample'], label='do_sample')
|
shared.gradio['do_sample'] = gr.Checkbox(value=shared.settings['do_sample'], label='do_sample')
|
||||||
shared.gradio['temperature_last'] = gr.Checkbox(value=shared.settings['temperature_last'], label='temperature_last', info='Moves temperature/dynamic temperature/quadratic sampling to the end of the sampler stack, ignoring their positions in "Sampler priority".')
|
shared.gradio['temperature_last'] = gr.Checkbox(value=shared.settings['temperature_last'], label='temperature_last', info='Moves temperature/dynamic temperature/quadratic sampling to the end of the sampler stack, ignoring their positions in "Sampler priority".')
|
||||||
shared.gradio['sampler_priority'] = gr.DragDrop(value=shared.settings['sampler_priority'], label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
|
shared.gradio['sampler_priority'] = gr.Textbox(value=shared.settings['sampler_priority'], lines=10, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
|
||||||
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=shared.settings['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
|
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=shared.settings['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ def create_ui():
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown("## Extensions & flags")
|
gr.Markdown("## Extensions & flags")
|
||||||
shared.gradio['save_settings'] = gr.Button(f'Save extensions settings to {shared.user_data_dir}/settings.yaml', interactive=not mu)
|
shared.gradio['save_settings'] = gr.Button(f'Save extensions settings to {shared.user_data_dir}/settings.yaml', elem_classes='refresh-button', interactive=not mu)
|
||||||
shared.gradio['reset_interface'] = gr.Button("Apply flags/extensions and restart", interactive=not mu)
|
shared.gradio['reset_interface'] = gr.Button("Apply flags/extensions and restart", interactive=not mu)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
|
@ -30,7 +30,7 @@ def create_ui():
|
||||||
if not mu:
|
if not mu:
|
||||||
shared.gradio['save_settings'].click(
|
shared.gradio['save_settings'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
handle_save_settings, gradio('interface_state', 'preset_menu', 'extensions_menu', 'show_controls', 'theme_state'), gradio('save_contents', 'save_filename', 'save_root', 'save_root_state', 'file_saver'), show_progress=False)
|
handle_save_settings, gradio('interface_state', 'preset_menu', 'extensions_menu', 'show_controls', 'theme_state'), gradio('save_contents', 'save_filename', 'save_root', 'file_saver'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['toggle_dark_mode'].click(
|
shared.gradio['toggle_dark_mode'].click(
|
||||||
lambda x: 'dark' if x == 'light' else 'light', gradio('theme_state'), gradio('theme_state')).then(
|
lambda x: 'dark' if x == 'light' else 'light', gradio('theme_state'), gradio('theme_state')).then(
|
||||||
|
|
@ -51,12 +51,10 @@ def create_ui():
|
||||||
|
|
||||||
def handle_save_settings(state, preset, extensions, show_controls, theme):
|
def handle_save_settings(state, preset, extensions, show_controls, theme):
|
||||||
contents = ui.save_settings(state, preset, extensions, show_controls, theme, manual_save=True)
|
contents = ui.save_settings(state, preset, extensions, show_controls, theme, manual_save=True)
|
||||||
root = str(shared.user_data_dir) + "/"
|
|
||||||
return [
|
return [
|
||||||
contents,
|
contents,
|
||||||
"settings.yaml",
|
"settings.yaml",
|
||||||
root,
|
str(shared.user_data_dir) + "/",
|
||||||
root,
|
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -95,6 +93,8 @@ def set_interface_arguments(extensions, bool_active):
|
||||||
setattr(shared.args, k, False)
|
setattr(shared.args, k, False)
|
||||||
for k in bool_active:
|
for k in bool_active:
|
||||||
setattr(shared.args, k, True)
|
setattr(shared.args, k, True)
|
||||||
|
if k == 'api':
|
||||||
|
shared.add_extension('openai', last=True)
|
||||||
|
|
||||||
shared.need_restart = True
|
shared.need_restart = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,10 +47,6 @@ def save_file(fname, contents):
|
||||||
logger.error(f'Invalid file path: \"{fname}\"')
|
logger.error(f'Invalid file path: \"{fname}\"')
|
||||||
return
|
return
|
||||||
|
|
||||||
if Path(abs_path_str).suffix.lower() not in ('.yaml', '.yml', '.json', '.txt', '.gbnf'):
|
|
||||||
logger.error(f'Refusing to save file with disallowed extension: \"{fname}\"')
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(abs_path_str, 'w', encoding='utf-8') as f:
|
with open(abs_path_str, 'w', encoding='utf-8') as f:
|
||||||
f.write(contents)
|
f.write(contents)
|
||||||
|
|
||||||
|
|
@ -81,6 +77,14 @@ def atoi(text):
|
||||||
return int(text) if text.isdigit() else text.lower()
|
return int(text) if text.isdigit() else text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# Replace multiple string pairs in a string
|
||||||
|
def replace_all(text, dic):
|
||||||
|
for i, j in dic.items():
|
||||||
|
text = text.replace(i, j)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
def natural_keys(text):
|
def natural_keys(text):
|
||||||
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
||||||
|
|
||||||
|
|
@ -105,9 +109,6 @@ def resolve_model_path(model_name_or_path, image_model=False):
|
||||||
before the default models directory.
|
before the default models directory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if model_name_or_path is None:
|
|
||||||
raise FileNotFoundError("No model specified.")
|
|
||||||
|
|
||||||
path_candidate = Path(model_name_or_path)
|
path_candidate = Path(model_name_or_path)
|
||||||
if path_candidate.exists():
|
if path_candidate.exists():
|
||||||
return path_candidate
|
return path_candidate
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import html
|
import html
|
||||||
import ipaddress
|
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import socket
|
import urllib.request
|
||||||
from concurrent.futures import as_completed
|
from concurrent.futures import as_completed
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from urllib.parse import parse_qs, quote_plus, urljoin, urlparse
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
@ -14,60 +13,34 @@ from modules import shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
def _validate_url(url):
|
|
||||||
"""Validate that a URL is safe to fetch (not targeting private/internal networks)."""
|
|
||||||
parsed = urlparse(url)
|
|
||||||
if parsed.scheme not in ('http', 'https'):
|
|
||||||
raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")
|
|
||||||
|
|
||||||
hostname = parsed.hostname
|
|
||||||
if not hostname:
|
|
||||||
raise ValueError("No hostname in URL")
|
|
||||||
|
|
||||||
# Resolve hostname and check all returned addresses
|
|
||||||
try:
|
|
||||||
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
|
|
||||||
ip = ipaddress.ip_address(sockaddr[0])
|
|
||||||
if not ip.is_global:
|
|
||||||
raise ValueError(f"Access to non-public address {ip} is blocked")
|
|
||||||
except socket.gaierror:
|
|
||||||
raise ValueError(f"Could not resolve hostname: {hostname}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_timestamp():
|
def get_current_timestamp():
|
||||||
"""Returns the current time in 24-hour format"""
|
"""Returns the current time in 24-hour format"""
|
||||||
return datetime.now().strftime('%b %d, %Y %H:%M')
|
return datetime.now().strftime('%b %d, %Y %H:%M')
|
||||||
|
|
||||||
|
|
||||||
def download_web_page(url, timeout=10, include_links=False):
|
def download_web_page(url, timeout=10):
|
||||||
"""
|
"""
|
||||||
Download a web page and extract its main content as Markdown text.
|
Download a web page and convert its HTML content to structured Markdown text.
|
||||||
"""
|
"""
|
||||||
import trafilatura
|
import html2text
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_validate_url(url)
|
|
||||||
headers = {
|
headers = {
|
||||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36'
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
||||||
}
|
}
|
||||||
max_redirects = 5
|
response = requests.get(url, headers=headers, timeout=timeout)
|
||||||
for _ in range(max_redirects):
|
response.raise_for_status() # Raise an exception for bad status codes
|
||||||
response = requests.get(url, headers=headers, timeout=timeout, allow_redirects=False)
|
|
||||||
if response.is_redirect and 'Location' in response.headers:
|
|
||||||
url = urljoin(url, response.headers['Location'])
|
|
||||||
_validate_url(url)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
response.raise_for_status()
|
# Initialize the HTML to Markdown converter
|
||||||
|
h = html2text.HTML2Text()
|
||||||
|
h.body_width = 0
|
||||||
|
h.ignore_images = True
|
||||||
|
h.ignore_links = True
|
||||||
|
|
||||||
result = trafilatura.extract(
|
# Convert the HTML to Markdown
|
||||||
response.text,
|
markdown_text = h.handle(response.text)
|
||||||
include_links=include_links,
|
|
||||||
output_format='markdown',
|
return markdown_text
|
||||||
url=url
|
|
||||||
)
|
|
||||||
return result or ""
|
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logger.error(f"Error downloading {url}: {e}")
|
logger.error(f"Error downloading {url}: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
@ -76,51 +49,35 @@ def download_web_page(url, timeout=10, include_links=False):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def perform_web_search(query, num_pages=3, max_workers=5, timeout=10, fetch_content=True):
|
def perform_web_search(query, num_pages=3, max_workers=5, timeout=10):
|
||||||
"""Perform web search and return results, optionally with page content"""
|
"""Perform web search and return results with content"""
|
||||||
try:
|
try:
|
||||||
search_url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}"
|
search_url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}"
|
||||||
|
|
||||||
agents = [
|
agents = [
|
||||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36"
|
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
|
||||||
]
|
]
|
||||||
|
|
||||||
response = requests.get(search_url, headers={'User-Agent': random.choice(agents)}, timeout=timeout)
|
response_text = ""
|
||||||
response.raise_for_status()
|
req = urllib.request.Request(search_url, headers={'User-Agent': random.choice(agents)})
|
||||||
response_text = response.text
|
with urllib.request.urlopen(req, timeout=timeout) as response:
|
||||||
|
response_text = response.read().decode('utf-8')
|
||||||
|
|
||||||
# Extract results - title and URL come from the same <a class="result__a"> element
|
# Extract results with regex
|
||||||
result_links = re.findall(r'<a[^>]*class="[^"]*result__a[^"]*"[^>]*>(.*?)</a>', response_text, re.DOTALL)
|
titles = re.findall(r'<a[^>]*class="[^"]*result__a[^"]*"[^>]*>(.*?)</a>', response_text, re.DOTALL)
|
||||||
result_tags = re.findall(r'<a([^>]*class="[^"]*result__a[^"]*"[^>]*)>', response_text, re.DOTALL)
|
urls = re.findall(r'<a[^>]*class="[^"]*result__url[^"]*"[^>]*>(.*?)</a>', response_text, re.DOTALL)
|
||||||
|
|
||||||
# Prepare download tasks
|
# Prepare download tasks
|
||||||
download_tasks = []
|
download_tasks = []
|
||||||
for i, (tag_attrs, raw_title) in enumerate(zip(result_tags, result_links)):
|
for i in range(min(len(titles), len(urls), num_pages)):
|
||||||
if num_pages is not None and i >= num_pages:
|
url = f"https://{urls[i].strip()}"
|
||||||
break
|
title = re.sub(r'<[^>]+>', '', titles[i]).strip()
|
||||||
# Extract href and resolve the actual URL from DuckDuckGo's redirect link
|
title = html.unescape(title)
|
||||||
href_match = re.search(r'href="([^"]*)"', tag_attrs)
|
download_tasks.append((url, title, i))
|
||||||
if not href_match:
|
|
||||||
continue
|
|
||||||
uddg = parse_qs(urlparse(html.unescape(href_match.group(1))).query).get('uddg', [''])[0]
|
|
||||||
if not uddg:
|
|
||||||
continue
|
|
||||||
title = html.unescape(re.sub(r'<[^>]+>', '', raw_title).strip())
|
|
||||||
download_tasks.append((uddg, title, len(download_tasks)))
|
|
||||||
|
|
||||||
search_results = [None] * len(download_tasks) # Pre-allocate to maintain order
|
search_results = [None] * len(download_tasks) # Pre-allocate to maintain order
|
||||||
|
|
||||||
if not fetch_content:
|
|
||||||
for url, title, index in download_tasks:
|
|
||||||
search_results[index] = {
|
|
||||||
'title': title,
|
|
||||||
'url': url,
|
|
||||||
'content': ''
|
|
||||||
}
|
|
||||||
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
# Download pages in parallel
|
# Download pages in parallel
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
# Submit all download tasks
|
# Submit all download tasks
|
||||||
|
|
|
||||||
50
one_click.py
50
one_click.py
|
|
@ -91,7 +91,7 @@ def get_gpu_choice():
|
||||||
"What is your GPU?",
|
"What is your GPU?",
|
||||||
{
|
{
|
||||||
'A': 'NVIDIA',
|
'A': 'NVIDIA',
|
||||||
'B': 'AMD - Linux only, ROCm 7.2',
|
'B': 'AMD - Linux/macOS only, requires ROCm 6.4',
|
||||||
'C': 'Apple M Series',
|
'C': 'Apple M Series',
|
||||||
'D': 'Intel Arc (beta)',
|
'D': 'Intel Arc (beta)',
|
||||||
'N': 'CPU mode'
|
'N': 'CPU mode'
|
||||||
|
|
@ -111,17 +111,18 @@ def get_gpu_choice():
|
||||||
def get_pytorch_install_command(gpu_choice):
|
def get_pytorch_install_command(gpu_choice):
|
||||||
"""Get PyTorch installation command based on GPU choice"""
|
"""Get PyTorch installation command based on GPU choice"""
|
||||||
base_cmd = f"python -m pip install torch=={TORCH_VERSION} "
|
base_cmd = f"python -m pip install torch=={TORCH_VERSION} "
|
||||||
pypi_fallback = " --extra-index-url https://pypi.org/simple/"
|
|
||||||
|
|
||||||
if gpu_choice == "NVIDIA_CUDA128":
|
if gpu_choice == "NVIDIA_CUDA128":
|
||||||
return base_cmd + "--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback
|
return base_cmd + "--index-url https://download.pytorch.org/whl/cu128"
|
||||||
elif gpu_choice == "AMD":
|
elif gpu_choice == "AMD":
|
||||||
py_tag = f"cp{PYTHON_VERSION.replace('.', '')}"
|
return base_cmd + "--index-url https://download.pytorch.org/whl/rocm6.4"
|
||||||
return f"python -m pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl --find-links https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/"
|
|
||||||
elif gpu_choice in ["APPLE", "NONE"]:
|
elif gpu_choice in ["APPLE", "NONE"]:
|
||||||
return base_cmd + "--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback
|
return base_cmd + "--index-url https://download.pytorch.org/whl/cpu"
|
||||||
elif gpu_choice == "INTEL":
|
elif gpu_choice == "INTEL":
|
||||||
return base_cmd + "--index-url https://download.pytorch.org/whl/xpu"
|
if is_linux():
|
||||||
|
return "python -m pip install torch==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
|
||||||
|
else:
|
||||||
|
return "python -m pip install torch==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
|
||||||
else:
|
else:
|
||||||
return base_cmd
|
return base_cmd
|
||||||
|
|
||||||
|
|
@ -129,17 +130,16 @@ def get_pytorch_install_command(gpu_choice):
|
||||||
def get_pytorch_update_command(gpu_choice):
|
def get_pytorch_update_command(gpu_choice):
|
||||||
"""Get PyTorch update command based on GPU choice"""
|
"""Get PyTorch update command based on GPU choice"""
|
||||||
base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} "
|
base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} "
|
||||||
pypi_fallback = " --extra-index-url https://pypi.org/simple/"
|
|
||||||
|
|
||||||
if gpu_choice == "NVIDIA_CUDA128":
|
if gpu_choice == "NVIDIA_CUDA128":
|
||||||
return f"{base_cmd}--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback
|
return f"{base_cmd} --index-url https://download.pytorch.org/whl/cu128"
|
||||||
elif gpu_choice == "AMD":
|
elif gpu_choice == "AMD":
|
||||||
py_tag = f"cp{PYTHON_VERSION.replace('.', '')}"
|
return f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.4"
|
||||||
return f"python -m pip install --upgrade https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl --find-links https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/"
|
|
||||||
elif gpu_choice in ["APPLE", "NONE"]:
|
elif gpu_choice in ["APPLE", "NONE"]:
|
||||||
return f"{base_cmd}--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback
|
return f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu"
|
||||||
elif gpu_choice == "INTEL":
|
elif gpu_choice == "INTEL":
|
||||||
return f"{base_cmd}--index-url https://download.pytorch.org/whl/xpu"
|
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
|
||||||
|
return f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
|
||||||
else:
|
else:
|
||||||
return base_cmd
|
return base_cmd
|
||||||
|
|
||||||
|
|
@ -194,8 +194,6 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False,
|
||||||
if environment:
|
if environment:
|
||||||
if is_windows():
|
if is_windows():
|
||||||
conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat")
|
conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat")
|
||||||
python_path = os.path.join(conda_env_path, "python.exe")
|
|
||||||
cmd = cmd.replace("python ", f'"{python_path}" ')
|
|
||||||
cmd = f'"{conda_bat_path}" activate "{conda_env_path}" >nul && {cmd}'
|
cmd = f'"{conda_bat_path}" activate "{conda_env_path}" >nul && {cmd}'
|
||||||
else:
|
else:
|
||||||
conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh")
|
conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh")
|
||||||
|
|
@ -270,7 +268,7 @@ def update_pytorch_and_python():
|
||||||
|
|
||||||
|
|
||||||
def clean_outdated_pytorch_cuda_dependencies():
|
def clean_outdated_pytorch_cuda_dependencies():
|
||||||
patterns = ["cu121", "cu122", "rocm6", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"]
|
patterns = ["cu121", "cu122", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"]
|
||||||
result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True)
|
result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True)
|
||||||
matching_packages = []
|
matching_packages = []
|
||||||
|
|
||||||
|
|
@ -316,6 +314,13 @@ def install_webui():
|
||||||
install_pytorch = get_pytorch_install_command(gpu_choice)
|
install_pytorch = get_pytorch_install_command(gpu_choice)
|
||||||
run_cmd(f"conda install -y ninja git && {install_pytorch}", assert_success=True, environment=True)
|
run_cmd(f"conda install -y ninja git && {install_pytorch}", assert_success=True, environment=True)
|
||||||
|
|
||||||
|
if gpu_choice == "INTEL":
|
||||||
|
# Install oneAPI dependencies via conda
|
||||||
|
print_big_message("Installing Intel oneAPI runtime libraries.")
|
||||||
|
run_cmd("conda install -y -c https://software.repos.intel.com/python/conda/ -c conda-forge dpcpp-cpp-rt=2024.0 mkl-dpcpp=2024.0", environment=True)
|
||||||
|
# Install libuv required by Intel-patched torch
|
||||||
|
run_cmd("conda install -y libuv", environment=True)
|
||||||
|
|
||||||
# Install the webui requirements
|
# Install the webui requirements
|
||||||
update_requirements(initial_installation=True, pull=False)
|
update_requirements(initial_installation=True, pull=False)
|
||||||
|
|
||||||
|
|
@ -358,10 +363,8 @@ def update_requirements(initial_installation=False, pull=True):
|
||||||
|
|
||||||
current_commit = get_current_commit()
|
current_commit = get_current_commit()
|
||||||
wheels_changed = not os.path.exists(state_file)
|
wheels_changed = not os.path.exists(state_file)
|
||||||
installed_wheels = set()
|
|
||||||
if not wheels_changed:
|
if not wheels_changed:
|
||||||
state = load_state()
|
state = load_state()
|
||||||
installed_wheels = set(state.get('installed_wheels', []))
|
|
||||||
if 'wheels_changed' in state or state.get('last_installed_commit') != current_commit:
|
if 'wheels_changed' in state or state.get('last_installed_commit') != current_commit:
|
||||||
wheels_changed = True
|
wheels_changed = True
|
||||||
|
|
||||||
|
|
@ -426,16 +429,8 @@ def update_requirements(initial_installation=False, pull=True):
|
||||||
|
|
||||||
# Prepare the requirements file
|
# Prepare the requirements file
|
||||||
textgen_requirements = open(requirements_file).read().splitlines()
|
textgen_requirements = open(requirements_file).read().splitlines()
|
||||||
all_whl_lines = [line.strip() for line in textgen_requirements if '.whl' in line]
|
|
||||||
|
|
||||||
if not initial_installation:
|
if not initial_installation and not wheels_changed:
|
||||||
if installed_wheels:
|
|
||||||
# Per-wheel comparison: only re-download wheels that changed
|
|
||||||
textgen_requirements = [
|
|
||||||
line for line in textgen_requirements
|
|
||||||
if '.whl' not in line or line.strip() not in installed_wheels
|
|
||||||
]
|
|
||||||
elif not wheels_changed:
|
|
||||||
textgen_requirements = [line for line in textgen_requirements if '.whl' not in line]
|
textgen_requirements = [line for line in textgen_requirements if '.whl' not in line]
|
||||||
|
|
||||||
with open('temp_requirements.txt', 'w') as file:
|
with open('temp_requirements.txt', 'w') as file:
|
||||||
|
|
@ -455,7 +450,6 @@ def update_requirements(initial_installation=False, pull=True):
|
||||||
# Save state after successful installation
|
# Save state after successful installation
|
||||||
state = load_state()
|
state = load_state()
|
||||||
state['last_installed_commit'] = current_commit
|
state['last_installed_commit'] = current_commit
|
||||||
state['installed_wheels'] = all_whl_lines
|
|
||||||
state.pop('wheels_changed', None)
|
state.pop('wheels_changed', None)
|
||||||
save_state(state)
|
save_state(state)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,21 @@
|
||||||
accelerate==1.13.*
|
accelerate==1.12.*
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
bitsandbytes==0.49.*
|
bitsandbytes==0.49.*
|
||||||
datasets
|
datasets
|
||||||
diffusers==0.37.*
|
diffusers==0.36.*
|
||||||
einops
|
einops
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
flash-linear-attention==0.4.*
|
flash-linear-attention==0.4.*
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
|
|
@ -25,15 +25,14 @@ scipy
|
||||||
sentencepiece
|
sentencepiece
|
||||||
tensorboard
|
tensorboard
|
||||||
torchao==0.15.*
|
torchao==0.15.*
|
||||||
trafilatura==2.0.0
|
transformers==5.3.*
|
||||||
transformers==5.5.*
|
|
||||||
triton-windows==3.5.1.post24; platform_system == "Windows"
|
triton-windows==3.5.1.post24; platform_system == "Windows"
|
||||||
tqdm
|
tqdm
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -41,11 +40,9 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# CUDA wheels
|
# CUDA wheels
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/ik_llama_cpp_binaries-0.110.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.23/exllamav3-0.0.23+cu128.torch2.9.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/ik_llama_cpp_binaries-0.110.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.23/exllamav3-0.0.23+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
|
||||||
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.28/exllamav3-0.0.28+cu128.torch2.9.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
|
|
||||||
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.28/exllamav3-0.0.28+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
|
|
||||||
https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
|
https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13"
|
||||||
https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
|
https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13"
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
accelerate==1.13.*
|
accelerate==1.12.*
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
datasets
|
datasets
|
||||||
diffusers==0.37.*
|
diffusers==0.36.*
|
||||||
einops
|
einops
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
|
|
@ -23,14 +23,13 @@ scipy
|
||||||
sentencepiece
|
sentencepiece
|
||||||
tensorboard
|
tensorboard
|
||||||
torchao==0.15.*
|
torchao==0.15.*
|
||||||
transformers==5.5.*
|
transformers==5.3.*
|
||||||
tqdm
|
tqdm
|
||||||
trafilatura==2.0.0
|
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -38,5 +37,5 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# AMD wheels
|
# AMD wheels
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
accelerate==1.13.*
|
accelerate==1.12.*
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
datasets
|
datasets
|
||||||
diffusers==0.37.*
|
diffusers==0.36.*
|
||||||
einops
|
einops
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
|
|
@ -23,14 +23,13 @@ scipy
|
||||||
sentencepiece
|
sentencepiece
|
||||||
tensorboard
|
tensorboard
|
||||||
torchao==0.15.*
|
torchao==0.15.*
|
||||||
transformers==5.5.*
|
transformers==5.3.*
|
||||||
tqdm
|
tqdm
|
||||||
trafilatura==2.0.0
|
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -38,4 +37,4 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# Mac wheels
|
# Mac wheels
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
accelerate==1.13.*
|
accelerate==1.12.*
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
datasets
|
datasets
|
||||||
diffusers==0.37.*
|
diffusers==0.36.*
|
||||||
einops
|
einops
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
|
|
@ -23,14 +23,13 @@ scipy
|
||||||
sentencepiece
|
sentencepiece
|
||||||
tensorboard
|
tensorboard
|
||||||
torchao==0.15.*
|
torchao==0.15.*
|
||||||
transformers==5.5.*
|
transformers==5.3.*
|
||||||
tqdm
|
tqdm
|
||||||
trafilatura==2.0.0
|
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -38,4 +37,4 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# Mac wheels
|
# Mac wheels
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
accelerate==1.13.*
|
accelerate==1.12.*
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
datasets
|
datasets
|
||||||
diffusers==0.37.*
|
diffusers==0.36.*
|
||||||
einops
|
einops
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
|
|
@ -23,14 +23,13 @@ scipy
|
||||||
sentencepiece
|
sentencepiece
|
||||||
tensorboard
|
tensorboard
|
||||||
torchao==0.15.*
|
torchao==0.15.*
|
||||||
transformers==5.5.*
|
transformers==5.3.*
|
||||||
tqdm
|
tqdm
|
||||||
trafilatura==2.0.0
|
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -38,7 +37,5 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# llama.cpp (CPU only)
|
# llama.cpp (CPU only)
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/ik_llama_cpp_binaries-0.110.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/ik_llama_cpp_binaries-0.110.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,19 @@
|
||||||
accelerate==1.13.*
|
accelerate==1.12.*
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
datasets
|
datasets
|
||||||
diffusers==0.37.*
|
diffusers==0.36.*
|
||||||
einops
|
einops
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
|
|
@ -23,14 +23,13 @@ scipy
|
||||||
sentencepiece
|
sentencepiece
|
||||||
tensorboard
|
tensorboard
|
||||||
torchao==0.15.*
|
torchao==0.15.*
|
||||||
transformers==5.5.*
|
transformers==5.3.*
|
||||||
tqdm
|
tqdm
|
||||||
trafilatura==2.0.0
|
|
||||||
wandb
|
wandb
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,21 @@
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
rich
|
rich
|
||||||
trafilatura==2.0.0
|
|
||||||
tqdm
|
tqdm
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -24,5 +23,5 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# CUDA wheels
|
# CUDA wheels
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,21 @@
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
rich
|
rich
|
||||||
trafilatura==2.0.0
|
|
||||||
tqdm
|
tqdm
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -24,5 +23,5 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# AMD wheels
|
# AMD wheels
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+rocm6.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,21 @@
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
rich
|
rich
|
||||||
trafilatura==2.0.0
|
|
||||||
tqdm
|
tqdm
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -24,4 +23,4 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# Mac wheels
|
# Mac wheels
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin"
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,21 @@
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
rich
|
rich
|
||||||
trafilatura==2.0.0
|
|
||||||
tqdm
|
tqdm
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -24,4 +23,4 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# Mac wheels
|
# Mac wheels
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin"
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,21 @@
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
rich
|
rich
|
||||||
trafilatura==2.0.0
|
|
||||||
tqdm
|
tqdm
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -24,5 +23,5 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# llama.cpp (CPU only)
|
# llama.cpp (CPU only)
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,21 @@
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
audioop-lts<1.0; python_version >= "3.13"
|
||||||
fastapi==0.112.4
|
fastapi==0.112.4
|
||||||
|
html2text==2025.4.15
|
||||||
huggingface-hub==1.5.*
|
huggingface-hub==1.5.*
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
pydantic==2.11.0
|
pydantic==2.11.0
|
||||||
pymupdf==1.27.*
|
pymupdf==1.27.1
|
||||||
python-docx==1.1.2
|
python-docx==1.1.2
|
||||||
pyyaml
|
pyyaml
|
||||||
requests
|
requests
|
||||||
rich
|
rich
|
||||||
trafilatura==2.0.0
|
|
||||||
tqdm
|
tqdm
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio-4.37.2+custom.9-py3-none-any.whl
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.9/gradio_client-1.0.2+custom.9-py3-none-any.whl
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.15
|
flask_cloudflared==0.0.15
|
||||||
|
|
@ -24,5 +23,5 @@ sse-starlette==1.6.5
|
||||||
tiktoken
|
tiktoken
|
||||||
|
|
||||||
# CUDA wheels
|
# CUDA wheels
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cu131-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu131-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/llama_cpp_binaries-0.110.0+cu131-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.87.0/llama_cpp_binaries-0.87.0+cu131-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
||||||
|
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
audioop-lts<1.0; python_version >= "3.13"
|
|
||||||
fastapi==0.112.4
|
|
||||||
huggingface-hub==1.5.*
|
|
||||||
jinja2==3.1.6
|
|
||||||
markdown
|
|
||||||
mcp==1.27.0
|
|
||||||
numpy==2.2.*
|
|
||||||
pydantic==2.11.0
|
|
||||||
pymupdf==1.27.*
|
|
||||||
python-docx==1.1.2
|
|
||||||
pyyaml
|
|
||||||
requests
|
|
||||||
rich
|
|
||||||
trafilatura==2.0.0
|
|
||||||
tqdm
|
|
||||||
|
|
||||||
# Gradio
|
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio-4.37.2+custom.19-py3-none-any.whl
|
|
||||||
https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.19/gradio_client-1.0.2+custom.19-py3-none-any.whl
|
|
||||||
|
|
||||||
# API
|
|
||||||
flask_cloudflared==0.0.15
|
|
||||||
sse-starlette==1.6.5
|
|
||||||
tiktoken
|
|
||||||
|
|
||||||
# CUDA wheels
|
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/ik_llama_cpp_binaries-0.110.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
|
|
||||||
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.110.0/ik_llama_cpp_binaries-0.110.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue