Compare commits

..

No commits in common. "main" and "v3.22" have entirely different histories.
main ... v3.22

151 changed files with 6577 additions and 6092 deletions

View file

@ -67,4 +67,4 @@ jobs:
uses: ./.github/workflows/build-portable-release.yml uses: ./.github/workflows/build-portable-release.yml
with: with:
version: ${{ inputs.version }} version: ${{ inputs.version }}
config: 'os:macos-15-intel,macos-14' config: 'os:macos-13,macos-14'

View file

@ -58,8 +58,9 @@ jobs:
run: | run: |
$matrix = @{ $matrix = @{
'os' = @('ubuntu-22.04', 'windows-2022') 'os' = @('ubuntu-22.04', 'windows-2022')
'pyver' = @("3.13") 'pyver' = @("3.11")
'cuda' = @("12.4", "13.1") 'avx' = @("AVX2")
'cuda' = @("12.4")
} }
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})} if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
@ -74,7 +75,7 @@ jobs:
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
build_wheels: build_wheels:
name: ${{ matrix.os }} ${{ matrix.pyver }} CUDA ${{ matrix.cuda }} name: ${{ matrix.os }} ${{ matrix.pyver }} CPU ${{ matrix.avx }} CUDA ${{ matrix.cuda }}
needs: define_matrix needs: define_matrix
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
@ -83,16 +84,17 @@ jobs:
run: run:
shell: pwsh shell: pwsh
env: env:
AVXVER: ${{ matrix.avx }}
PCKGVER: ${{ inputs.version }} PCKGVER: ${{ inputs.version }}
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v4
with: with:
repository: 'oobabooga/text-generation-webui' repository: 'oobabooga/text-generation-webui'
ref: ${{ inputs.version }} ref: ${{ inputs.version }}
submodules: 'recursive' submodules: 'recursive'
- uses: actions/setup-python@v6 - uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.pyver }} python-version: ${{ matrix.pyver }}
@ -111,20 +113,21 @@ jobs:
# Define common variables # Define common variables
CUDA_VERSION="${{ matrix.cuda }}" CUDA_VERSION="${{ matrix.cuda }}"
AVX_SUPPORT="${{ matrix.avx }}"
VERSION="${{ inputs.version }}" VERSION="${{ inputs.version }}"
# 1. Set platform-specific variables # 1. Set platform-specific variables
if [[ "$RUNNER_OS" == "Windows" ]]; then if [[ "$RUNNER_OS" == "Windows" ]]; then
PLATFORM="windows" PLATFORM="windows"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-pc-windows-msvc-install_only.tar.gz"
PIP_PATH="portable_env/python.exe -m pip" PIP_PATH="portable_env/python.exe -m pip"
PACKAGES_PATH="portable_env/Lib/site-packages" PACKAGES_PATH="portable_env/Lib/site-packages"
rm start_linux.sh start_macos.sh rm start_linux.sh start_macos.sh
else else
PLATFORM="linux" PLATFORM="linux"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-unknown-linux-gnu-install_only.tar.gz"
PIP_PATH="portable_env/bin/python -m pip" PIP_PATH="portable_env/bin/python -m pip"
PACKAGES_PATH="portable_env/lib/python3.13/site-packages" PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
rm start_macos.sh start_windows.bat rm start_macos.sh start_windows.bat
fi fi
@ -135,14 +138,17 @@ jobs:
tar -xzf python-build.tar.gz tar -xzf python-build.tar.gz
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env" mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
# 3. Prepare requirements file based on CUDA version # 3. Prepare requirements file based on AVX and CUDA
cd "text-generation-webui-${VERSION_CLEAN}" if [[ "$AVX_SUPPORT" == "AVX2" ]]; then
if [[ "$CUDA_VERSION" == "13.1" ]]; then BASE_REQ_FILE="requirements/portable/requirements.txt"
REQ_FILE="requirements/portable/requirements_cuda131.txt"
else else
REQ_FILE="requirements/portable/requirements.txt" BASE_REQ_FILE="requirements/portable/requirements_noavx2.txt"
fi fi
# Create CUDA-specific requirements file if needed
cd "text-generation-webui-${VERSION_CLEAN}"
REQ_FILE="$BASE_REQ_FILE"
# 4. Install packages # 4. Install packages
echo "Installing Python packages from $REQ_FILE..." echo "Installing Python packages from $REQ_FILE..."
$PIP_PATH install --target="./$PACKAGES_PATH" -r "$REQ_FILE" $PIP_PATH install --target="./$PACKAGES_PATH" -r "$REQ_FILE"
@ -150,16 +156,15 @@ jobs:
# 5. Clean up # 5. Clean up
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
# 6. Create archive # 6. Create ZIP file
cd .. cd ..
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.zip"
echo "Creating archive: $ZIP_NAME"
if [[ "$RUNNER_OS" == "Windows" ]]; then if [[ "$RUNNER_OS" == "Windows" ]]; then
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.zip" powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
echo "Creating archive: $ARCHIVE_NAME"
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
else else
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.tar.gz" zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
echo "Creating archive: $ARCHIVE_NAME"
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
fi fi
- name: Upload files to a GitHub release - name: Upload files to a GitHub release
@ -168,7 +173,7 @@ jobs:
continue-on-error: true continue-on-error: true
with: with:
repo_token: ${{ secrets.GITHUB_TOKEN }} repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ../textgen-portable-* file: ../textgen-portable-*.zip
tag: ${{ inputs.version }} tag: ${{ inputs.version }}
file_glob: true file_glob: true
make_latest: false make_latest: false

View file

@ -57,8 +57,9 @@ jobs:
id: set-matrix id: set-matrix
run: | run: |
$matrix = @{ $matrix = @{
'os' = @('ubuntu-22.04', 'windows-2022') 'os' = @('ubuntu-22.04')
'pyver' = @("3.13") 'pyver' = @("3.11")
'avx' = @("AVX2")
} }
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})} if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
@ -73,7 +74,7 @@ jobs:
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
build_wheels: build_wheels:
name: ${{ matrix.os }} ${{ matrix.pyver }} name: ${{ matrix.os }} ${{ matrix.pyver }} CPU ${{ matrix.avx }}
needs: define_matrix needs: define_matrix
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
@ -82,16 +83,17 @@ jobs:
run: run:
shell: pwsh shell: pwsh
env: env:
AVXVER: ${{ matrix.avx }}
PCKGVER: ${{ inputs.version }} PCKGVER: ${{ inputs.version }}
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v4
with: with:
repository: 'oobabooga/text-generation-webui' repository: 'oobabooga/text-generation-webui'
ref: ${{ inputs.version }} ref: ${{ inputs.version }}
submodules: 'recursive' submodules: 'recursive'
- uses: actions/setup-python@v6 - uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.pyver }} python-version: ${{ matrix.pyver }}
@ -109,22 +111,15 @@ jobs:
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
# Define common variables # Define common variables
AVX_SUPPORT="${{ matrix.avx }}"
VERSION="${{ inputs.version }}" VERSION="${{ inputs.version }}"
# 1. Set platform-specific variables # 1. Set platform-specific variables (Linux only for ROCm)
if [[ "$RUNNER_OS" == "Windows" ]]; then
PLATFORM="windows"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz"
PIP_PATH="portable_env/python.exe -m pip"
PACKAGES_PATH="portable_env/Lib/site-packages"
rm start_linux.sh start_macos.sh
else
PLATFORM="linux" PLATFORM="linux"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-unknown-linux-gnu-install_only.tar.gz"
PIP_PATH="portable_env/bin/python -m pip" PIP_PATH="portable_env/bin/python -m pip"
PACKAGES_PATH="portable_env/lib/python3.13/site-packages" PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
rm start_macos.sh start_windows.bat rm start_macos.sh start_windows.bat
fi
# 2. Download and extract Python # 2. Download and extract Python
cd .. cd ..
@ -133,8 +128,13 @@ jobs:
tar -xzf python-build.tar.gz tar -xzf python-build.tar.gz
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env" mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
# 3. Prepare requirements file # 3. Prepare requirements file based on AVX
REQ_FILE="requirements/portable/requirements_amd.txt" if [[ "$AVX_SUPPORT" == "AVX2" ]]; then
BASE_REQ_FILE="requirements/portable/requirements_amd.txt"
else
BASE_REQ_FILE="requirements/portable/requirements_amd_noavx2.txt"
fi
REQ_FILE="$BASE_REQ_FILE"
cd "text-generation-webui-${VERSION_CLEAN}" cd "text-generation-webui-${VERSION_CLEAN}"
@ -145,17 +145,12 @@ jobs:
# 5. Clean up # 5. Clean up
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
# 6. Create archive # 6. Create ZIP file
cd .. cd ..
if [[ "$RUNNER_OS" == "Windows" ]]; then ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm.zip"
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.zip" echo "Creating archive: $ZIP_NAME"
echo "Creating archive: $ARCHIVE_NAME"
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME" zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
else
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.tar.gz"
echo "Creating archive: $ARCHIVE_NAME"
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
fi
- name: Upload files to a GitHub release - name: Upload files to a GitHub release
id: upload-release id: upload-release
@ -163,7 +158,7 @@ jobs:
continue-on-error: true continue-on-error: true
with: with:
repo_token: ${{ secrets.GITHUB_TOKEN }} repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ../textgen-portable-* file: ../textgen-portable-*.zip
tag: ${{ inputs.version }} tag: ${{ inputs.version }}
file_glob: true file_glob: true
make_latest: false make_latest: false

View file

@ -58,7 +58,8 @@ jobs:
run: | run: |
$matrix = @{ $matrix = @{
'os' = @('ubuntu-22.04', 'windows-2022') 'os' = @('ubuntu-22.04', 'windows-2022')
'pyver' = @("3.13") 'pyver' = @("3.11")
'avx' = @("AVX2")
} }
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})} if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
@ -73,7 +74,7 @@ jobs:
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
build_wheels: build_wheels:
name: ${{ matrix.os }} ${{ matrix.pyver }} name: ${{ matrix.os }} ${{ matrix.pyver }} CPU ${{ matrix.avx }}
needs: define_matrix needs: define_matrix
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
@ -82,16 +83,17 @@ jobs:
run: run:
shell: pwsh shell: pwsh
env: env:
AVXVER: ${{ matrix.avx }}
PCKGVER: ${{ inputs.version }} PCKGVER: ${{ inputs.version }}
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v4
with: with:
repository: 'oobabooga/text-generation-webui' repository: 'oobabooga/text-generation-webui'
ref: ${{ inputs.version }} ref: ${{ inputs.version }}
submodules: 'recursive' submodules: 'recursive'
- uses: actions/setup-python@v6 - uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.pyver }} python-version: ${{ matrix.pyver }}
@ -109,20 +111,21 @@ jobs:
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
# Define common variables # Define common variables
AVX_SUPPORT="${{ matrix.avx }}"
VERSION="${{ inputs.version }}" VERSION="${{ inputs.version }}"
# 1. Set platform-specific variables # 1. Set platform-specific variables
if [[ "$RUNNER_OS" == "Windows" ]]; then if [[ "$RUNNER_OS" == "Windows" ]]; then
PLATFORM="windows" PLATFORM="windows"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-pc-windows-msvc-install_only.tar.gz"
PIP_PATH="portable_env/python.exe -m pip" PIP_PATH="portable_env/python.exe -m pip"
PACKAGES_PATH="portable_env/Lib/site-packages" PACKAGES_PATH="portable_env/Lib/site-packages"
rm start_linux.sh start_macos.sh rm start_linux.sh start_macos.sh
else else
PLATFORM="linux" PLATFORM="linux"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-unknown-linux-gnu-install_only.tar.gz"
PIP_PATH="portable_env/bin/python -m pip" PIP_PATH="portable_env/bin/python -m pip"
PACKAGES_PATH="portable_env/lib/python3.13/site-packages" PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
rm start_macos.sh start_windows.bat rm start_macos.sh start_windows.bat
fi fi
@ -133,8 +136,13 @@ jobs:
tar -xzf python-build.tar.gz tar -xzf python-build.tar.gz
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env" mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
# 3. Prepare requirements file # 3. Prepare requirements file based on AVX
REQ_FILE="requirements/portable/requirements_vulkan.txt" if [[ "$AVX_SUPPORT" == "AVX2" ]]; then
BASE_REQ_FILE="requirements/portable/requirements_vulkan.txt"
else
BASE_REQ_FILE="requirements/portable/requirements_vulkan_noavx2.txt"
fi
REQ_FILE="$BASE_REQ_FILE"
cd "text-generation-webui-${VERSION_CLEAN}" cd "text-generation-webui-${VERSION_CLEAN}"
@ -145,16 +153,15 @@ jobs:
# 5. Clean up # 5. Clean up
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
# 6. Create archive # 6. Create ZIP file
cd .. cd ..
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.zip"
echo "Creating archive: $ZIP_NAME"
if [[ "$RUNNER_OS" == "Windows" ]]; then if [[ "$RUNNER_OS" == "Windows" ]]; then
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.zip" powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
echo "Creating archive: $ARCHIVE_NAME"
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
else else
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.tar.gz" zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
echo "Creating archive: $ARCHIVE_NAME"
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
fi fi
- name: Upload files to a GitHub release - name: Upload files to a GitHub release
@ -163,7 +170,7 @@ jobs:
continue-on-error: true continue-on-error: true
with: with:
repo_token: ${{ secrets.GITHUB_TOKEN }} repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ../textgen-portable-* file: ../textgen-portable-*.zip
tag: ${{ inputs.version }} tag: ${{ inputs.version }}
file_glob: true file_glob: true
make_latest: false make_latest: false

View file

@ -58,7 +58,8 @@ jobs:
run: | run: |
$matrix = @{ $matrix = @{
'os' = @('ubuntu-22.04', 'windows-2022', 'macos-14') 'os' = @('ubuntu-22.04', 'windows-2022', 'macos-14')
'pyver' = @("3.13") 'pyver' = @("3.11")
'avx' = @("AVX2")
} }
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})} if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
@ -73,7 +74,7 @@ jobs:
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
build_wheels: build_wheels:
name: ${{ matrix.os }} ${{ matrix.pyver }} name: ${{ matrix.os }} ${{ matrix.pyver }} CPU ${{ matrix.avx }}
needs: define_matrix needs: define_matrix
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
@ -82,16 +83,17 @@ jobs:
run: run:
shell: pwsh shell: pwsh
env: env:
AVXVER: ${{ matrix.avx }}
PCKGVER: ${{ inputs.version }} PCKGVER: ${{ inputs.version }}
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v4
with: with:
repository: 'oobabooga/text-generation-webui' repository: 'oobabooga/text-generation-webui'
ref: ${{ inputs.version }} ref: ${{ inputs.version }}
submodules: 'recursive' submodules: 'recursive'
- uses: actions/setup-python@v6 - uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.pyver }} python-version: ${{ matrix.pyver }}
@ -109,35 +111,36 @@ jobs:
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
# Define common variables # Define common variables
AVX_SUPPORT="${{ matrix.avx }}"
VERSION="${{ inputs.version }}" VERSION="${{ inputs.version }}"
OS_TYPE="${{ matrix.os }}" OS_TYPE="${{ matrix.os }}"
# 1. Set platform-specific variables # 1. Set platform-specific variables
if [[ "$RUNNER_OS" == "Windows" ]]; then if [[ "$RUNNER_OS" == "Windows" ]]; then
PLATFORM="windows-cpu" PLATFORM="windows-cpu"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-pc-windows-msvc-install_only.tar.gz"
PIP_PATH="portable_env/python.exe -m pip" PIP_PATH="portable_env/python.exe -m pip"
PACKAGES_PATH="portable_env/Lib/site-packages" PACKAGES_PATH="portable_env/Lib/site-packages"
rm start_linux.sh start_macos.sh rm start_linux.sh start_macos.sh
elif [[ "$RUNNER_OS" == "macOS" ]]; then elif [[ "$RUNNER_OS" == "macOS" ]]; then
if [[ "$OS_TYPE" == "macos-15-intel" ]]; then if [[ "$OS_TYPE" == "macos-13" ]]; then
PLATFORM="macos-x86_64" PLATFORM="macos-x86_64"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-apple-darwin-install_only.tar.gz" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-apple-darwin-install_only.tar.gz"
REQ_TYPE="apple_intel" REQ_TYPE="apple_intel"
else else
PLATFORM="macos-arm64" PLATFORM="macos-arm64"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-aarch64-apple-darwin-install_only.tar.gz" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-aarch64-apple-darwin-install_only.tar.gz"
REQ_TYPE="apple_silicon" REQ_TYPE="apple_silicon"
fi fi
PIP_PATH="portable_env/bin/python -m pip" PIP_PATH="portable_env/bin/python -m pip"
PACKAGES_PATH="portable_env/lib/python3.13/site-packages" PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
rm start_linux.sh start_windows.bat rm start_linux.sh start_windows.bat
else else
# Linux case # Linux case
PLATFORM="linux-cpu" PLATFORM="linux-cpu"
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-unknown-linux-gnu-install_only.tar.gz"
PIP_PATH="portable_env/bin/python -m pip" PIP_PATH="portable_env/bin/python -m pip"
PACKAGES_PATH="portable_env/lib/python3.13/site-packages" PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
rm start_macos.sh start_windows.bat rm start_macos.sh start_windows.bat
fi fi
@ -148,18 +151,23 @@ jobs:
tar -xzf python-build.tar.gz tar -xzf python-build.tar.gz
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env" mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
# 3. Prepare requirements file based on platform # 3. Prepare requirements file based on platform and AVX
cd "text-generation-webui-${VERSION_CLEAN}" cd "text-generation-webui-${VERSION_CLEAN}"
# Select requirements file based on platform # Select requirements file based on platform
if [[ "$RUNNER_OS" == "macOS" ]]; then if [[ "$RUNNER_OS" == "macOS" ]]; then
if [[ "$OS_TYPE" == "macos-15-intel" ]]; then if [[ "$OS_TYPE" == "macos-13" ]]; then
REQ_FILE="requirements/portable/requirements_apple_intel.txt" REQ_FILE="requirements/portable/requirements_apple_intel.txt"
else else
REQ_FILE="requirements/portable/requirements_apple_silicon.txt" REQ_FILE="requirements/portable/requirements_apple_silicon.txt"
fi fi
else else
# For Windows and Linux, check AVX support
if [[ "$AVX_SUPPORT" == "AVX2" ]]; then
REQ_FILE="requirements/portable/requirements_cpu_only.txt" REQ_FILE="requirements/portable/requirements_cpu_only.txt"
else
REQ_FILE="requirements/portable/requirements_cpu_only_noavx2.txt"
fi
fi fi
echo "Using requirements file: $REQ_FILE" echo "Using requirements file: $REQ_FILE"
@ -171,16 +179,15 @@ jobs:
# 5. Clean up # 5. Clean up
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
# 6. Create archive # 6. Create ZIP file
cd .. cd ..
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.zip"
echo "Creating archive: $ZIP_NAME"
if [[ "$RUNNER_OS" == "Windows" ]]; then if [[ "$RUNNER_OS" == "Windows" ]]; then
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.zip" powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
echo "Creating archive: $ARCHIVE_NAME"
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
else else
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.tar.gz" zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
echo "Creating archive: $ARCHIVE_NAME"
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
fi fi
- name: Upload files to a GitHub release - name: Upload files to a GitHub release
@ -189,7 +196,7 @@ jobs:
continue-on-error: true continue-on-error: true
with: with:
repo_token: ${{ secrets.GITHUB_TOKEN }} repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ../textgen-portable-* file: ../textgen-portable-*.zip
tag: ${{ inputs.version }} tag: ${{ inputs.version }}
file_glob: true file_glob: true
make_latest: false make_latest: false

View file

@ -51,7 +51,7 @@
"source": [ "source": [
"#@title 2. Launch the web UI\n", "#@title 2. Launch the web UI\n",
"\n", "\n",
"#@markdown You can provide a direct GGUF link or a Hugging Face model URL.\n", "#@markdown If unsure about the branch, write \"main\" or leave it blank.\n",
"\n", "\n",
"import os\n", "import os\n",
"from pathlib import Path\n", "from pathlib import Path\n",
@ -72,9 +72,9 @@
" ./start_linux.sh\n", " ./start_linux.sh\n",
"\n", "\n",
"# Parameters\n", "# Parameters\n",
"model_url = \"https://huggingface.co/unsloth/Qwen3.5-9B-GGUF/resolve/main/Qwen3.5-9B-Q4_K_M.gguf\" #@param {type:\"string\"}\n", "model_url = \"https://huggingface.co/turboderp/gemma-2-9b-it-exl2\" #@param {type:\"string\"}\n",
"branch = \"\" #@param {type:\"string\"}\n", "branch = \"8.0bpw\" #@param {type:\"string\"}\n",
"command_line_flags = \"--load-in-4bit --use_double_quant\" #@param {type:\"string\"}\n", "command_line_flags = \"--n-gpu-layers 128 --load-in-4bit --use_double_quant --no_flash_attn\" #@param {type:\"string\"}\n",
"api = False #@param {type:\"boolean\"}\n", "api = False #@param {type:\"boolean\"}\n",
"\n", "\n",
"if api:\n", "if api:\n",
@ -83,28 +83,26 @@
" command_line_flags += f\" {param}\"\n", " command_line_flags += f\" {param}\"\n",
"\n", "\n",
"model_url = model_url.strip()\n", "model_url = model_url.strip()\n",
"model_name = \"\"\n",
"if model_url != \"\":\n", "if model_url != \"\":\n",
" if not model_url.startswith('http'):\n", " if not model_url.startswith('http'):\n",
" model_url = 'https://huggingface.co/' + model_url\n", " model_url = 'https://huggingface.co/' + model_url\n",
"\n", "\n",
" branch = branch.strip()\n", " # Download the model\n",
" if '/resolve/' in model_url:\n", " url_parts = model_url.strip('/').strip().split('/')\n",
" model_name = model_url.split('?')[0].split('/')[-1]\n", " output_folder = f\"{url_parts[-2]}_{url_parts[-1]}\"\n",
" !python download-model.py {model_url}\n", " branch = branch.strip('\"\\' ')\n",
" else:\n", " if branch.strip() not in ['', 'main']:\n",
" url_parts = model_url.strip('/').split('/')\n", " output_folder += f\"_{branch}\"\n",
" model_name = f\"{url_parts[-2]}_{url_parts[-1]}\"\n",
" if branch not in ['', 'main']:\n",
" model_name += f\"_{branch}\"\n",
" !python download-model.py {model_url} --branch {branch}\n", " !python download-model.py {model_url} --branch {branch}\n",
" else:\n", " else:\n",
" !python download-model.py {model_url}\n", " !python download-model.py {model_url}\n",
"else:\n",
" output_folder = \"\"\n",
"\n", "\n",
"# Start the web UI\n", "# Start the web UI\n",
"cmd = f\"./start_linux.sh {command_line_flags} --share\"\n", "cmd = f\"./start_linux.sh {command_line_flags} --share\"\n",
"if model_name != \"\":\n", "if output_folder != \"\":\n",
" cmd += f\" --model {model_name}\"\n", " cmd += f\" --model {output_folder}\"\n",
"\n", "\n",
"!$cmd" "!$cmd"
], ],

265
README.md
View file

@ -13,7 +13,7 @@
# Text Generation Web UI # Text Generation Web UI
A Gradio web UI for running Large Language Models locally. 100% private and offline. Supports text generation, vision, tool-calling, training, image generation, and more. A Gradio web UI for Large Language Models.
[Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason) [Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason)
@ -21,23 +21,29 @@ A Gradio web UI for running Large Language Models locally. 100% private and offl
|:---:|:---:| |:---:|:---:|
|![Image1](https://github.com/oobabooga/screenshots/raw/main/DEFAULT-3.5.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/PARAMETERS-3.5.png) | |![Image1](https://github.com/oobabooga/screenshots/raw/main/DEFAULT-3.5.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/PARAMETERS-3.5.png) |
## 🔥 News
- The project now supports **image generation**! Including Z-Image-Turbo, 4bit/8bit quantization, `torch.compile`, and LLM-generated prompt variations ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
## Features ## Features
- **Multiple backends**: [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Switch between backends and models without restarting. - Supports multiple local text generation backends, including [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), [ExLlamaV2](https://github.com/turboderp-org/exllamav2), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) (the latter via its own [Dockerfile](https://github.com/oobabooga/text-generation-webui/blob/main/docker/TensorRT-LLM/Dockerfile)).
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents. - Easy setup: Choose between **portable builds** (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or the one-click installer that creates a self-contained `installer_files` directory.
- **Vision (multimodal)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
- **Tool-calling**: Models can call custom functions during chat — web search, page fetching, math, and more. Each tool is a single `.py` file, easy to create and extend ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Tool-Calling-Tutorial)).
- **OpenAI-compatible API**: Chat and Completions endpoints with tool-calling support. Use as a local drop-in replacement for the OpenAI API ([examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples)).
- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Easy setup**: [Portable builds](https://github.com/oobabooga/text-generation-webui/releases) (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or a one-click installer for the full feature set.
- 100% offline and private, with zero telemetry, external resources, or remote update requests. - 100% offline and private, with zero telemetry, external resources, or remote update requests.
- `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters. Prompts are automatically formatted with Jinja2 templates. - **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
- Edit messages, navigate between message versions, and branch conversations at any point. - **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
- Free-form text generation in the Notebook tab without being limited to chat turns. - **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- Multiple sampling parameters and generation options for sophisticated text generation control. - **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
- Aesthetic UI with dark and light themes. - Aesthetic UI with dark and light themes.
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions. - Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
- `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters.
- Automatic prompt formatting using Jinja2 templates. You never need to worry about prompt formats.
- Edit messages, navigate between message versions, and branch conversations at any point.
- Multiple sampling parameters and generation options for sophisticated text generation control.
- Switch between different models in the UI without restarting.
- Automatic GPU layers for GGUF models (on NVIDIA GPUs).
- Free-form text generation in the Notebook tab without being limited to chat turns.
- OpenAI-compatible API with Chat and Completions endpoints, including tool-calling support – see [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples).
- Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details. - Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.
## How to install ## How to install
@ -46,10 +52,9 @@ A Gradio web UI for running Large Language Models locally. 100% private and offl
No installation needed – just download, unzip and run. All dependencies included. No installation needed – just download, unzip and run. All dependencies included.
Download from here: **https://github.com/oobabooga/text-generation-webui/releases** Compatible with GGUF (llama.cpp) models on Windows, Linux, and macOS.
- Builds are provided for Linux, Windows, and macOS, with options for CUDA, Vulkan, ROCm, and CPU-only. Download from here: **https://github.com/oobabooga/text-generation-webui/releases**
- Compatible with GGUF (llama.cpp) models.
#### Option 2: Manual portable install with venv #### Option 2: Manual portable install with venv
@ -81,7 +86,7 @@ deactivate
#### Option 3: One-click installer #### Option 3: One-click installer
For users who need additional backends (ExLlamaV3, Transformers), training, image generation, or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch. For users who need additional backends (ExLlamaV3, Transformers) or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch.
1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it. 1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it.
2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`. 2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`.
@ -136,7 +141,7 @@ For other platforms, download from: https://github.com/conda-forge/miniforge/rel
#### 1. Create a new conda environment #### 1. Create a new conda environment
``` ```
conda create -n textgen python=3.13 conda create -n textgen python=3.11
conda activate textgen conda activate textgen
``` ```
@ -144,12 +149,12 @@ conda activate textgen
| System | GPU | Command | | System | GPU | Command |
|--------|---------|---------| |--------|---------|---------|
| Linux/WSL | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` | | Linux/WSL | NVIDIA | `pip3 install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128` |
| Linux/WSL | CPU only | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu` | | Linux/WSL | CPU only | `pip3 install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu` |
| Linux | AMD | `pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp313-cp313-linux_x86_64.whl` | | Linux | AMD | `pip3 install torch==2.7.1 --index-url https://download.pytorch.org/whl/rocm6.2.4` |
| MacOS + MPS | Any | `pip3 install torch==2.9.1` | | MacOS + MPS | Any | `pip3 install torch==2.7.1` |
| Windows | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` | | Windows | NVIDIA | `pip3 install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128` |
| Windows | CPU only | `pip3 install torch==2.9.1` | | Windows | CPU only | `pip3 install torch==2.7.1` |
The up-to-date commands can be found here: https://pytorch.org/get-started/locally/. The up-to-date commands can be found here: https://pytorch.org/get-started/locally/.
@ -169,13 +174,16 @@ pip install -r requirements/full/<requirements file according to table below>
Requirements file to use: Requirements file to use:
| GPU | requirements file to use | | GPU | CPU | requirements file to use |
|--------|---------| |--------|---------|---------|
| NVIDIA | `requirements.txt` | | NVIDIA | has AVX2 | `requirements.txt` |
| AMD | `requirements_amd.txt` | | NVIDIA | no AVX2 | `requirements_noavx2.txt` |
| CPU only | `requirements_cpu_only.txt` | | AMD | has AVX2 | `requirements_amd.txt` |
| Apple Intel | `requirements_apple_intel.txt` | | AMD | no AVX2 | `requirements_amd_noavx2.txt` |
| Apple Silicon | `requirements_apple_silicon.txt` | | CPU only | has AVX2 | `requirements_cpu_only.txt` |
| CPU only | no AVX2 | `requirements_cpu_only_noavx2.txt` |
| Apple | Intel | `requirements_apple_intel.txt` |
| Apple | Apple Silicon | `requirements_apple_silicon.txt` |
### Start the web UI ### Start the web UI
@ -201,7 +209,7 @@ ln -s docker/{nvidia/Dockerfile,nvidia/docker-compose.yml,.dockerignore} .
For AMD GPU: For AMD GPU:
ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} . ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} .
For Intel GPU: For Intel GPU:
ln -s docker/{intel/Dockerfile,intel/docker-compose.yml,.dockerignore} . ln -s docker/{intel/Dockerfile,amd/docker-compose.yml,.dockerignore} .
For CPU only: For CPU only:
ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} . ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} .
cp docker/.env.example .env cp docker/.env.example .env
@ -236,24 +244,17 @@ List of command-line flags
</summary> </summary>
```txt ```txt
usage: server.py [-h] [--user-data-dir USER_DATA_DIR] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS] usage: server.py [-h] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS]
[--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--image-model IMAGE_MODEL] [--image-model-dir IMAGE_MODEL_DIR] [--image-dtype {bfloat16,float16}] [--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT]
[--image-attn-backend {flash_attention_2,sdpa}] [--image-cpu-offload] [--image-compile] [--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}] [--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT] [--ctx-size-draft CTX_SIZE_DRAFT] [--gpu-layers N] [--mmproj MMPROJ] [--streaming-llm]
[--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT] [--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT] [--tensor-split TENSOR_SPLIT] [--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa]
[--ctx-size-draft CTX_SIZE_DRAFT] [--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}] [--spec-ngram-size-n SPEC_NGRAM_SIZE_N] [--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16] [--no-cache] [--trust-remote-code]
[--spec-ngram-size-m SPEC_NGRAM_SIZE_M] [--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS] [--gpu-layers N] [--cpu-moe] [--mmproj MMPROJ] [--streaming-llm] [--tensor-split TENSOR_SPLIT] [--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE] [--quant_type QUANT_TYPE]
[--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--ubatch-size UBATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa] [--enable-tp] [--tp-backend TP_BACKEND] [--gpu-split GPU_SPLIT] [--autosplit] [--cfg-cache] [--no_flash_attn] [--no_xformers] [--no_sdpa] [--num_experts_per_token N] [--cpp-runner]
[--parallel PARALLEL] [--fit-target FIT_TARGET] [--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16] [--deepspeed] [--nvme-offload-dir NVME_OFFLOAD_DIR] [--local_rank LOCAL_RANK] [--alpha_value ALPHA_VALUE] [--rope_freq_base ROPE_FREQ_BASE] [--compress_pos_emb COMPRESS_POS_EMB]
[--no-cache] [--trust-remote-code] [--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share] [--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH]
[--quant_type QUANT_TYPE] [--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors] [--portable] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT]
[--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4] [--nowebui]
[--portable] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4]
[--nowebui] [--temperature N] [--dynatemp-low N] [--dynatemp-high N] [--dynatemp-exponent N] [--smoothing-factor N] [--smoothing-curve N] [--min-p N] [--top-p N] [--top-k N]
[--typical-p N] [--xtc-threshold N] [--xtc-probability N] [--epsilon-cutoff N] [--eta-cutoff N] [--tfs N] [--top-a N] [--top-n-sigma N] [--adaptive-target N] [--adaptive-decay N]
[--dry-multiplier N] [--dry-allowed-length N] [--dry-base N] [--repetition-penalty N] [--frequency-penalty N] [--presence-penalty N] [--encoder-repetition-penalty N]
[--no-repeat-ngram-size N] [--repetition-penalty-range N] [--penalty-alpha N] [--guidance-scale N] [--mirostat-mode N] [--mirostat-tau N] [--mirostat-eta N]
[--do-sample | --no-do-sample] [--dynamic-temperature | --no-dynamic-temperature] [--temperature-last | --no-temperature-last] [--sampler-priority N] [--dry-sequence-breakers N]
[--enable-thinking | --no-enable-thinking] [--reasoning-effort N] [--chat-template-file CHAT_TEMPLATE_FILE]
Text Generation Web UI Text Generation Web UI
@ -261,8 +262,7 @@ options:
-h, --help show this help message and exit -h, --help show this help message and exit
Basic settings: Basic settings:
--user-data-dir USER_DATA_DIR Path to the user data directory. Default: auto-detected. --multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.
--multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.
--model MODEL Name of the model to load by default. --model MODEL Name of the model to load by default.
--lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces. --lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.
--model-dir MODEL_DIR Path to directory with all the models. --model-dir MODEL_DIR Path to directory with all the models.
@ -274,23 +274,14 @@ Basic settings:
--verbose Print the prompts to the terminal. --verbose Print the prompts to the terminal.
--idle-timeout IDLE_TIMEOUT Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again. --idle-timeout IDLE_TIMEOUT Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.
Image model:
--image-model IMAGE_MODEL Name of the image model to select on startup (overrides saved setting).
--image-model-dir IMAGE_MODEL_DIR Path to directory with all the image models.
--image-dtype {bfloat16,float16} Data type for image model.
--image-attn-backend {flash_attention_2,sdpa} Attention backend for image model.
--image-cpu-offload Enable CPU offloading for image model.
--image-compile Compile the image model for faster inference.
--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}
Quantization method for image model.
Model loader: Model loader:
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT- --loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2,
LLM. TensorRT-LLM.
Context and cache: Context and cache:
--ctx-size, --n_ctx, --max_seq_len N Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. --ctx-size N, --n_ctx N, --max_seq_len N Context size in tokens.
--cache-type, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8). --cache-type N, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits
separately, e.g. q4_q8).
Speculative decoding: Speculative decoding:
--model-draft MODEL_DRAFT Path to the draft model for speculative decoding. --model-draft MODEL_DRAFT Path to the draft model for speculative decoding.
@ -298,15 +289,9 @@ Speculative decoding:
--gpu-layers-draft GPU_LAYERS_DRAFT Number of layers to offload to the GPU for the draft model. --gpu-layers-draft GPU_LAYERS_DRAFT Number of layers to offload to the GPU for the draft model.
--device-draft DEVICE_DRAFT Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1 --device-draft DEVICE_DRAFT Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1
--ctx-size-draft CTX_SIZE_DRAFT Size of the prompt context for the draft model. If 0, uses the same as the main model. --ctx-size-draft CTX_SIZE_DRAFT Size of the prompt context for the draft model. If 0, uses the same as the main model.
--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}
Draftless speculative decoding type. Recommended: ngram-mod.
--spec-ngram-size-n SPEC_NGRAM_SIZE_N N-gram lookup size for ngram speculative decoding.
--spec-ngram-size-m SPEC_NGRAM_SIZE_M Draft n-gram size for ngram speculative decoding.
--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS Minimum n-gram hits for ngram-map speculative decoding.
llama.cpp: llama.cpp:
--gpu-layers, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto. --gpu-layers N, --n-gpu-layers N Number of layers to offload to the GPU.
--cpu-moe Move the experts to the CPU (for MoE models).
--mmproj MMPROJ Path to the mmproj file for vision models. --mmproj MMPROJ Path to the mmproj file for vision models.
--streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed. --streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.
--tensor-split TENSOR_SPLIT Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40. --tensor-split TENSOR_SPLIT Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.
@ -314,22 +299,17 @@ llama.cpp:
--no-mmap Prevent mmap from being used. --no-mmap Prevent mmap from being used.
--mlock Force the system to keep the model in RAM. --mlock Force the system to keep the model in RAM.
--no-kv-offload Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance. --no-kv-offload Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.
--batch-size BATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the application level batch size. --batch-size BATCH_SIZE Maximum number of prompt tokens to batch together when calling llama_eval.
--ubatch-size UBATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the max physical batch size for computation (device level).
--threads THREADS Number of threads to use. --threads THREADS Number of threads to use.
--threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing. --threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing.
--numa Activate NUMA task allocation for llama.cpp. --numa Activate NUMA task allocation for llama.cpp.
--parallel PARALLEL Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set
ctx_size to 32768.
--fit-target FIT_TARGET Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.
Default: 1024.
--extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU" --extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"
Transformers/Accelerate: Transformers/Accelerate:
--cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow. --cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow.
--cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading. --cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading.
--disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. --disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.
--disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to. --disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to. Defaults to "user_data/cache".
--load-in-8bit Load the model with 8-bit precision (using bitsandbytes). --load-in-8bit Load the model with 8-bit precision (using bitsandbytes).
--bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. --bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
--no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost. --no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.
@ -345,10 +325,30 @@ bitsandbytes 4-bit:
--quant_type QUANT_TYPE quant_type for 4-bit. Valid options: nf4, fp4. --quant_type QUANT_TYPE quant_type for 4-bit. Valid options: nf4, fp4.
ExLlamaV3: ExLlamaV3:
--gpu-split GPU_SPLIT Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.
--enable-tp, --enable_tp Enable Tensor Parallelism (TP) to split the model across GPUs. --enable-tp, --enable_tp Enable Tensor Parallelism (TP) to split the model across GPUs.
--tp-backend TP_BACKEND The backend for tensor parallelism. Valid options: native, nccl. Default: native. --tp-backend TP_BACKEND The backend for tensor parallelism. Valid options: native, nccl. Default: native.
--cfg-cache Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.
ExLlamaV2:
--gpu-split GPU_SPLIT Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.
--autosplit Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.
--cfg-cache ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.
--no_flash_attn Force flash-attention to not be used.
--no_xformers Force xformers to not be used.
--no_sdpa Force Torch SDPA to not be used.
--num_experts_per_token N Number of experts to use for generation. Applies to MoE models like Mixtral.
TensorRT-LLM:
--cpp-runner Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn't support streaming yet.
DeepSpeed:
--deepspeed Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.
--nvme-offload-dir NVME_OFFLOAD_DIR DeepSpeed: Directory to use for ZeRO-3 NVME offloading.
--local_rank LOCAL_RANK DeepSpeed: Optional argument for distributed setups.
RoPE:
--alpha_value ALPHA_VALUE Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.
--rope_freq_base ROPE_FREQ_BASE If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).
--compress_pos_emb COMPRESS_POS_EMB Positional embeddings compression factor. Should be set to (context length) / (model's original context length). Equal to 1/rope_freq_scale.
Gradio: Gradio:
--listen Make the web UI reachable from your local network. --listen Make the web UI reachable from your local network.
@ -366,7 +366,7 @@ Gradio:
API: API:
--api Enable the API extension. --api Enable the API extension.
--public-api Create a public URL for the API using Cloudflare. --public-api Create a public URL for the API using Cloudflare.
--public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. --public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.
--api-port API_PORT The listening port for the API. --api-port API_PORT The listening port for the API.
--api-key API_KEY API authentication key. --api-key API_KEY API authentication key.
@ -374,88 +374,65 @@ API:
--api-enable-ipv6 Enable IPv6 for the API --api-enable-ipv6 Enable IPv6 for the API
--api-disable-ipv4 Disable IPv4 for the API --api-disable-ipv4 Disable IPv4 for the API
--nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode. --nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode.
API generation defaults:
--temperature N Temperature
--dynatemp-low N Dynamic temperature low
--dynatemp-high N Dynamic temperature high
--dynatemp-exponent N Dynamic temperature exponent
--smoothing-factor N Smoothing factor
--smoothing-curve N Smoothing curve
--min-p N Min P
--top-p N Top P
--top-k N Top K
--typical-p N Typical P
--xtc-threshold N XTC threshold
--xtc-probability N XTC probability
--epsilon-cutoff N Epsilon cutoff
--eta-cutoff N Eta cutoff
--tfs N TFS
--top-a N Top A
--top-n-sigma N Top N Sigma
--adaptive-target N Adaptive target
--adaptive-decay N Adaptive decay
--dry-multiplier N DRY multiplier
--dry-allowed-length N DRY allowed length
--dry-base N DRY base
--repetition-penalty N Repetition penalty
--frequency-penalty N Frequency penalty
--presence-penalty N Presence penalty
--encoder-repetition-penalty N Encoder repetition penalty
--no-repeat-ngram-size N No repeat ngram size
--repetition-penalty-range N Repetition penalty range
--penalty-alpha N Penalty alpha
--guidance-scale N Guidance scale
--mirostat-mode N Mirostat mode
--mirostat-tau N Mirostat tau
--mirostat-eta N Mirostat eta
--do-sample, --no-do-sample Do sample
--dynamic-temperature, --no-dynamic-temperature Dynamic temperature
--temperature-last, --no-temperature-last Temperature last
--sampler-priority N Sampler priority
--dry-sequence-breakers N DRY sequence breakers
--enable-thinking, --no-enable-thinking Enable thinking
--reasoning-effort N Reasoning effort
--chat-template-file CHAT_TEMPLATE_FILE Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model's
built-in template.
``` ```
</details> </details>
## Downloading models ## Downloading models
1. Download a GGUF model file from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf). Models should be placed in the folder `text-generation-webui/user_data/models`. They are usually downloaded from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf).
2. Place it in the `user_data/models` folder.
That's it. The UI will detect it automatically. To check if a GGUF model will fit in your hardware before downloading it, you can use this tool I created:
To check what will fit your GPU, you can use the [VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator). [Accurate GGUF VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator)
<details> * GGUF models are a single file and should be placed directly into `user_data/models`. Example:
<summary>Other model types (Transformers, EXL3)</summary>
Models that consist of multiple files (like 16-bit Transformers models and EXL3 models) should be placed in a subfolder inside `user_data/models`:
``` ```
text-generation-webui text-generation-webui
└── user_data └── user_data
└── models └── models
└── Qwen_Qwen3-8B └── llama-2-13b-chat.Q4_K_M.gguf
├── config.json
├── generation_config.json
├── model-00001-of-00004.safetensors
├── ...
├── tokenizer_config.json
└── tokenizer.json
``` ```
These formats require the one-click installer (not the portable build). * The remaining model types (like 16-bit Transformers models and EXL3 models) are made of several files and must be placed in a subfolder. Example:
</details>
```
text-generation-webui
└── user_data
└── models
└── lmsys_vicuna-33b-v1.3
├── config.json
├── generation_config.json
├── pytorch_model-00001-of-00007.bin
├── pytorch_model-00002-of-00007.bin
├── pytorch_model-00003-of-00007.bin
├── pytorch_model-00004-of-00007.bin
├── pytorch_model-00005-of-00007.bin
├── pytorch_model-00006-of-00007.bin
├── pytorch_model-00007-of-00007.bin
├── pytorch_model.bin.index.json
├── special_tokens_map.json
├── tokenizer_config.json
└── tokenizer.model
```
In both cases, you can use the "Model" tab of the UI to download the model from Hugging Face automatically. It is also possible to download it via the command-line with:
```
python download-model.py organization/model
```
Run `python download-model.py --help` to see all the options.
## Documentation ## Documentation
https://github.com/oobabooga/text-generation-webui/wiki https://github.com/oobabooga/text-generation-webui/wiki
## Google Colab notebook
https://colab.research.google.com/github/oobabooga/text-generation-webui/blob/main/Colab-TextGen-GPU.ipynb
## Community ## Community
https://www.reddit.com/r/Oobabooga/ https://www.reddit.com/r/Oobabooga/

View file

@ -21,7 +21,6 @@ set INSTALL_ENV_DIR=%cd%\installer_files\env
set PYTHONNOUSERSITE=1 set PYTHONNOUSERSITE=1
set PYTHONPATH= set PYTHONPATH=
set PYTHONHOME= set PYTHONHOME=
set PYTHONUTF8=1
set "CUDA_PATH=%INSTALL_ENV_DIR%" set "CUDA_PATH=%INSTALL_ENV_DIR%"
set "CUDA_HOME=%CUDA_PATH%" set "CUDA_HOME=%CUDA_PATH%"

View file

@ -2,7 +2,6 @@
display: grid; display: grid;
align-items: start; align-items: start;
grid-template-columns: 60px minmax(0, 1fr); grid-template-columns: 60px minmax(0, 1fr);
width: min(100%, calc(724px + 60px));
padding-bottom: 22px; padding-bottom: 22px;
padding-top: 6px; padding-top: 6px;
font-size: 18px; font-size: 18px;
@ -92,6 +91,9 @@
} }
.message-body p { .message-body p {
margin-bottom: 0 !important;
font-size: 16px !important;
line-height: 1.5 !important;
color: #e0e0e0 !important; /* Light color for text */ color: #e0e0e0 !important; /* Light color for text */
} }
@ -120,7 +122,7 @@
} }
.message-body p { .message-body p {
font-size: 14px !important; font-size: 14px !important; /* Smaller text for mobile */
} }
.username { .username {

View file

@ -4,7 +4,6 @@
display: grid; display: grid;
align-items: start; align-items: start;
grid-template-columns: 60px minmax(0, 1fr); grid-template-columns: 60px minmax(0, 1fr);
width: min(100%, calc(724px + 60px + 90px));
padding-bottom: 21px; padding-bottom: 21px;
padding-top: 7px; padding-top: 7px;
font-size: 18px; font-size: 18px;
@ -87,8 +86,10 @@
border-radius: 20px; border-radius: 20px;
} }
.message-body p, .message-body li { .message-body p {
margin-bottom: 0 !important;
font-size: 18px !important; font-size: 18px !important;
line-height: 1.428571429 !important;
color: rgb(243 244 246) !important; color: rgb(243 244 246) !important;
text-shadow: 2px 2px 2px rgb(0 0 0); text-shadow: 2px 2px 2px rgb(0 0 0);
font-weight: 500; font-weight: 500;
@ -126,7 +127,7 @@
padding-left: 0; padding-left: 0;
} }
.message-body p, .message-body li { .message-body p {
font-size: 16px !important; font-size: 16px !important;
} }

View file

@ -19,5 +19,4 @@
padding-bottom: 1.5em; padding-bottom: 1.5em;
padding-top: 0.5em; padding-top: 0.5em;
grid-template-columns: 70px minmax(0, 1fr); grid-template-columns: 70px minmax(0, 1fr);
width: min(100%, calc(724px + 70px));
} }

View file

@ -2,7 +2,6 @@
display: grid; display: grid;
align-items: start; align-items: start;
grid-template-columns: 60px minmax(0, 1fr); grid-template-columns: 60px minmax(0, 1fr);
width: min(100%, calc(724px + 60px));
padding-bottom: 1.5em; padding-bottom: 1.5em;
padding-top: 0.5em; padding-top: 0.5em;
font-size: 15px; font-size: 15px;
@ -47,10 +46,16 @@
border-radius: 20px; border-radius: 20px;
} }
.message-body p, .message-body li { .message-body p {
font-size: 15px !important;
line-height: 22.5px !important;
font-weight: 500; font-weight: 500;
} }
.message-body p, .chat .message-body ul, .chat .message-body ol {
margin-bottom: 10px !important;
}
.dark .message-body p em { .dark .message-body p em {
color: rgb(138 138 138) !important; color: rgb(138 138 138) !important;
} }

View file

@ -1,5 +1,4 @@
.message { .message {
width: min(100%, calc(724px + 60px));
padding-bottom: 22px; padding-bottom: 22px;
padding-top: 3px; padding-top: 3px;
font-size: 15px; font-size: 15px;
@ -61,10 +60,8 @@
text-align: right; text-align: right;
} }
.dark .circle-bot + .text div, .dark .circle-bot + .text *, .dark .circle-bot + .text div, .dark .circle-bot + .text * {
.dark .chat .message .circle-bot + .text .message-body :is(h1, h2, h3, h4, h5, h6), color: #000;
.dark .chat .message .circle-bot + .text .message-body a {
color: #000 !important;
} }
.text { .text {
@ -79,14 +76,19 @@
font-weight: bold; font-weight: bold;
} }
.message-body {
}
.message-body img { .message-body img {
max-width: 300px; max-width: 300px;
max-height: 300px; max-height: 300px;
border-radius: 20px; border-radius: 20px;
} }
.message-body p, .message-body li { .message-body p {
margin-bottom: 0 !important;
font-size: 15px !important; font-size: 15px !important;
line-height: 1.428571429 !important;
font-weight: 500; font-weight: 500;
} }

View file

@ -1,6 +1,5 @@
.message { .message {
display: block; display: block;
width: min(100%, 724px);
padding-top: 0; padding-top: 0;
padding-bottom: 21px; padding-bottom: 21px;
font-size: 15px; font-size: 15px;
@ -78,8 +77,14 @@
border-radius: 12px; border-radius: 12px;
} }
.message-body p, .message-body li { .message-body p {
font-size: 15px !important; font-size: 15px !important;
line-height: 1.4 !important;
font-weight: 400;
}
.message-body p:first-child {
margin-top: 0 !important;
} }
.dark .message-body p em { .dark .message-body p em {
@ -95,3 +100,6 @@
margin-top: 8px; margin-top: 8px;
} }
.message-body p, .chat .message-body ul, .chat .message-body ol {
margin-bottom: 10px !important;
}

View file

@ -19,14 +19,12 @@
color: #d1d5db !important; color: #d1d5db !important;
} }
.chat .message-body :is(th, td), .chat .message-body :is(th, td) {
.prose hr {
border-color: #40404096 !important; border-color: #40404096 !important;
} }
.dark .chat .message-body :is(th, td), .dark .chat .message-body :is(th, td) {
.dark .prose hr { border-color: #ffffff75 !important;
border-color: rgb(255 255 255 / 30%) !important;
} }
.chat .message-body :is(p, ul, ol) { .chat .message-body :is(p, ul, ol) {
@ -78,7 +76,7 @@
.chat .user-message .text, .chat .user-message .text,
.chat .assistant-message .text { .chat .assistant-message .text {
max-width: 724px; max-width: 700px;
margin-left: auto; margin-left: auto;
margin-right: auto; margin-right: auto;
} }

View file

@ -400,6 +400,7 @@ audio {
} }
.chat .message { .chat .message {
width: min(100%, 48rem);
margin-left: auto; margin-left: auto;
margin-right: auto; margin-right: auto;
text-align: start; text-align: start;
@ -430,19 +431,10 @@ audio {
font-size: 16px; font-size: 16px;
} }
.dark .message-body h1, .dark .message-body :is(h1, h2, h3, h4, h5, h6) {
.dark .message-body h2,
.dark .message-body h3,
.dark .message-body h4,
.dark .message-body h5,
.dark .message-body h6 {
color: white !important; color: white !important;
} }
.dark .message-body blockquote {
border-left-color: rgb(255 255 255 / 30%);
}
.message-body h1 { .message-body h1 {
font-weight: 800; font-weight: 800;
font-size: 2.25em; font-size: 2.25em;
@ -723,7 +715,7 @@ audio {
.hover-menu { .hover-menu {
display: none; display: none;
position: absolute; position: absolute;
bottom: 100%; bottom: 80%;
left: 0; left: 0;
box-shadow: 0 0 5px rgb(0 0 0 / 25%); box-shadow: 0 0 5px rgb(0 0 0 / 25%);
z-index: 10000; z-index: 10000;
@ -839,20 +831,9 @@ audio {
} }
} }
.message-body p, .message-body li { .message-body ol, .message-body ul {
line-height: 1.75 !important;
}
.message-body p, .message-body ul, .message-body ol {
margin: 1.25em 0 !important;
}
.message-body :is(p, ul, ol):first-child {
margin-top: 0 !important; margin-top: 0 !important;
} margin-bottom: 1.25em !important;
.message-body :is(p, ul, ol):last-child {
margin-bottom: 0 !important;
} }
/* ---------------------------------------------- /* ----------------------------------------------
@ -1022,49 +1003,6 @@ audio {
padding-right: 0.5rem; padding-right: 0.5rem;
} }
#new-chat-wrapper {
display: contents;
}
.new-chat-arrow {
cursor: pointer;
position: relative;
padding: 0;
margin-right: -15px;
height: 39.594px;
display: flex;
align-items: center;
}
.new-chat-menu {
display: none;
position: absolute;
top: 0;
left: 0;
padding-top: 1.2em;
z-index: var(--layer-top);
white-space: nowrap;
}
.new-chat-arrow:hover .new-chat-menu {
display: block;
}
.new-chat-menu-item {
cursor: pointer;
padding: var(--size-2);
background: var(--background-fill-primary);
box-shadow: var(--shadow-drop-lg);
border-radius: var(--container-radius);
color: var(--body-text-color);
font-size: var(--text-md);
font-weight: var(--button-large-text-weight);
}
.new-chat-menu-item:hover {
background: var(--background-fill-secondary);
}
#past-chats-row, #past-chats-row,
#chat-controls { #chat-controls {
width: 260px; width: 260px;
@ -1435,6 +1373,7 @@ audio {
overflow-wrap: break-word; overflow-wrap: break-word;
max-height: 250px; max-height: 250px;
overflow-y: scroll; overflow-y: scroll;
contain: layout;
} }
.chat .message-body .thinking-content p, .chat .message-body .thinking-content p,
@ -1706,7 +1645,7 @@ button:focus {
} }
#user-description textarea { #user-description textarea {
height: calc(100vh - 334px) !important; height: calc(100vh - 231px) !important;
min-height: 90px !important; min-height: 90px !important;
} }
@ -1723,7 +1662,7 @@ button:focus {
.chat-parent { .chat-parent {
/* Optimize for scrolling performance */ /* Optimize for scrolling performance */
will-change: scroll-position; will-change: scroll-position;
contain: style paint; contain: layout style paint;
/* Ensure GPU acceleration */ /* Ensure GPU acceleration */
transform: translateZ(0); transform: translateZ(0);
@ -1858,112 +1797,3 @@ button#swap-height-width {
top: 0; top: 0;
left: calc(100% - 174px); left: calc(100% - 174px);
} }
table {
border-collapse: collapse;
}
.table-wrapper {
overflow-x: auto;
}
.message-body :is(td, th) {
word-break: normal;
overflow-wrap: normal;
}
table, tr, td, th, thead {
border: 0;
}
td + td,
th + th { border-left: 1px solid; }
tr + tr td,
tr + tr th { border-top: 1px solid; }
thead + tbody tr:first-child td,
thead + tbody tr:first-child th { border-top: 1px solid; }
/* ------------------------------------------------
Tools CheckboxGroup - vertical DragDrop-like style
------------------------------------------------ */
/* "Refresh list" link in the Tools label */
.tools-refresh-link {
cursor: pointer;
}
/* Checkbox list container */
#tools-group {
padding: 0 !important;
border-width: 0 !important;
background: transparent !important;
min-height: 0 !important;
}
#tools-group .wrap {
display: flex;
flex-direction: column;
flex-wrap: nowrap;
gap: 4px;
padding: 0;
margin-top: var(--spacing-lg);
max-height: 350px;
overflow-y: auto;
}
/* Pretty scrollbar for the tools list */
#tools-group .wrap::-webkit-scrollbar {
width: 8px;
height: 8px;
}
#tools-group .wrap::-webkit-scrollbar-track {
background: transparent;
}
#tools-group .wrap::-webkit-scrollbar-thumb,
#tools-group .wrap::-webkit-scrollbar-thumb:hover {
background: var(--neutral-300);
border-radius: 30px;
}
.dark #tools-group .wrap::-webkit-scrollbar-thumb,
.dark #tools-group .wrap::-webkit-scrollbar-thumb:hover {
background: rgb(255 255 255 / 6.25%);
border-radius: 10px;
}
#tools-group .wrap::-webkit-scrollbar-corner {
background: transparent;
}
/* Each checkbox item */
#tools-group label {
display: flex;
align-items: center;
gap: 8px;
padding: 5px 8px;
border-radius: var(--radius-sm, 4px);
background: var(--block-background-fill);
border: 1px solid var(--border-color-primary);
color: var(--body-text-color);
font-size: var(--input-text-size);
font-weight: var(--input-text-weight);
cursor: pointer;
user-select: none;
transition: border-color 0.15s ease, background 0.15s ease;
box-shadow: none;
}
#tools-group label:hover {
border-color: var(--input-border-color-focus);
}
#tools-group label span {
flex: 1;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}

View file

@ -1,3 +1,8 @@
.env .env
Dockerfile Dockerfile
/user_data /characters
/loras
/models
/presets
/prompts
/training

View file

@ -1,8 +1,8 @@
# specify which cuda arch version your card supports (NVIDIA only) # by default the Dockerfile specifies these versions: 3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX
# https://developer.nvidia.com/cuda-gpus # however, for it to work for me I had to specify the exact version for my card (2060), which was 7.5
# or run: nvidia-smi --query-gpu=name,compute_cap --format=csv # https://developer.nvidia.com/cuda-gpus you can find the version for your card here
# default in docker-compose.yml covers RTX 3090 (8.6) and RTX 4090 (8.9) # Or for a programmatic approach run `nvidia-smi --query-gpu=name,compute_cap --format=csv`
TORCH_CUDA_ARCH_LIST=8.6;8.9+PTX TORCH_CUDA_ARCH_LIST=7.5
# the port the webui binds to on the host # the port the webui binds to on the host
HOST_PORT=7860 HOST_PORT=7860
# the port the webui binds to inside the container # the port the webui binds to inside the container
@ -19,3 +19,6 @@ APP_RUNTIME_GID=6972
# override default app build permissions (handy for deploying to cloud) # override default app build permissions (handy for deploying to cloud)
#APP_GID=6972 #APP_GID=6972
#APP_UID=6972 #APP_UID=6972
# Set cache env
TRANSFORMERS_CACHE=/home/app/text-generation-webui/cache/
HF_HOME=/home/app/text-generation-webui/cache/

View file

@ -1,24 +1,27 @@
FROM nvidia/cuda:13.0.1-cudnn-runtime-ubuntu24.04 FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
# Install Python 3.12, Git, and OpenMPI # Install Git
RUN apt update && apt install -y python3.12 python3-pip git build-essential openmpi-bin libopenmpi-dev RUN apt update && apt install -y git
# System-wide TensorRT-LLM requirements
RUN apt install -y openmpi-bin libopenmpi-dev
# Set the working directory # Set the working directory
WORKDIR /app WORKDIR /app
# This is needed to avoid an error about "Failed to build mpi4py" in the next command
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
# Install text-generation-webui # Install text-generation-webui
RUN git clone https://github.com/oobabooga/text-generation-webui RUN git clone https://github.com/oobabooga/text-generation-webui
WORKDIR /app/text-generation-webui WORKDIR /app/text-generation-webui
RUN pip install --break-system-packages -r requirements/full/requirements.txt RUN pip install -r requirements.txt
# This is needed to avoid an error about "Failed to build mpi4py" in the next command
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
# Install TensorRT-LLM # Install TensorRT-LLM
RUN pip3 install --break-system-packages tensorrt_llm==1.1.0 --extra-index-url https://pypi.nvidia.com RUN pip3 install tensorrt_llm==0.10.0 -U --pre --extra-index-url https://pypi.nvidia.com
# Expose the necessary port for the Python server # Expose the necessary port for the Python server
EXPOSE 7860 5000 EXPOSE 7860 5000
# Run the Python server.py script with the specified command # Run the Python server.py script with the specified command
CMD ["python3", "server.py", "--api", "--listen"] CMD ["python", "server.py", "--api", "--listen"]

View file

@ -1,6 +1,7 @@
# BUILDER # BUILDER
FROM ubuntu:22.04 FROM ubuntu:22.04
WORKDIR /builder WORKDIR /builder
ARG TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX}"
ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}" ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}"
ARG APP_UID="${APP_UID:-6972}" ARG APP_UID="${APP_UID:-6972}"
ARG APP_GID="${APP_GID:-6972}" ARG APP_GID="${APP_GID:-6972}"
@ -13,7 +14,8 @@ WORKDIR /home/app/
RUN git clone https://github.com/oobabooga/text-generation-webui.git RUN git clone https://github.com/oobabooga/text-generation-webui.git
WORKDIR /home/app/text-generation-webui WORKDIR /home/app/text-generation-webui
RUN GPU_CHOICE=B LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose RUN GPU_CHOICE=B LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} COPY /user_data/CMD_FLAGS.txt /home/app/text-generation-webui/user_data
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005}
WORKDIR /home/app/text-generation-webui WORKDIR /home/app/text-generation-webui
# set umask to ensure group read / write at runtime # set umask to ensure group read / write at runtime
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh

View file

@ -4,6 +4,22 @@ services:
build: build:
context: . context: .
args: args:
# Requirements file to use:
# | GPU | requirements file to use |
# |--------|---------|
# | NVIDIA | `requirements.txt` |
# | AMD | `requirements_amd.txt` |
# | CPU only | `requirements_cpu_only.txt` |
# | Apple Intel | `requirements_apple_intel.txt` |
# | Apple Silicon | `requirements_apple_silicon.txt` |
# Default: `requirements.txt`
# BUILD_REQUIREMENTS: requirements.txt
# Extension requirements to build:
# BUILD_EXTENSIONS:
# specify which CUDA compute capability your card supports: https://developer.nvidia.com/cuda-gpus
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-} BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
APP_GID: ${APP_GID:-6972} APP_GID: ${APP_GID:-6972}
APP_UID: ${APP_UID:-6972} APP_UID: ${APP_UID:-6972}

View file

@ -1,9 +1,14 @@
# BUILDER # BUILDER
FROM ubuntu:22.04 FROM ubuntu:22.04
WORKDIR /builder WORKDIR /builder
ARG TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX}"
ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}" ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}"
ARG APP_UID="${APP_UID:-6972}" ARG APP_UID="${APP_UID:-6972}"
ARG APP_GID="${APP_GID:-6972}" ARG APP_GID="${APP_GID:-6972}"
ARG GPU_CHOICE=A
ARG USE_CUDA118=FALSE
ARG LAUNCH_AFTER_INSTALL=FALSE
ARG INSTALL_EXTENSIONS=TRUE
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw \ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw \
apt update && \ apt update && \
@ -13,7 +18,8 @@ WORKDIR /home/app/
RUN git clone https://github.com/oobabooga/text-generation-webui.git RUN git clone https://github.com/oobabooga/text-generation-webui.git
WORKDIR /home/app/text-generation-webui WORKDIR /home/app/text-generation-webui
RUN GPU_CHOICE=N LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose RUN GPU_CHOICE=N LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} COPY CMD_FLAGS.txt /home/app/text-generation-webui/
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005}
# set umask to ensure group read / write at runtime # set umask to ensure group read / write at runtime
WORKDIR /home/app/text-generation-webui WORKDIR /home/app/text-generation-webui
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh

View file

@ -4,6 +4,22 @@ services:
build: build:
context: . context: .
args: args:
# Requirements file to use:
# | GPU | requirements file to use |
# |--------|---------|
# | NVIDIA | `requirements.txt` |
# | AMD | `requirements_amd.txt` |
# | CPU only | `requirements_cpu_only.txt` |
# | Apple Intel | `requirements_apple_intel.txt` |
# | Apple Silicon | `requirements_apple_silicon.txt` |
# Default: `requirements.txt`
# BUILD_REQUIREMENTS: requirements.txt
# Extension requirements to build:
# BUILD_EXTENSIONS:
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-} BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
APP_GID: ${APP_GID:-6972} APP_GID: ${APP_GID:-6972}
APP_UID: ${APP_UID:-6972} APP_UID: ${APP_UID:-6972}
@ -15,4 +31,14 @@ services:
stdin_open: true stdin_open: true
tty: true tty: true
volumes: volumes:
- ./user_data:/home/app/text-generation-webui/user_data - ./cache:/home/app/text-generation-webui/cache
- ./characters:/home/app/text-generation-webui/characters
- ./extensions:/home/app/text-generation-webui/extensions
- ./loras:/home/app/text-generation-webui/loras
- ./logs:/home/app/text-generation-webui/logs
- ./models:/home/app/text-generation-webui/models
- ./presets:/home/app/text-generation-webui/presets
- ./prompts:/home/app/text-generation-webui/prompts
- ./softprompts:/home/app/text-generation-webui/softprompts
- ./training:/home/app/text-generation-webui/training
- ./cloudflared:/etc/cloudflared

View file

@ -1,6 +1,7 @@
# BUILDER # BUILDER
FROM ubuntu:22.04 FROM ubuntu:22.04
WORKDIR /builder WORKDIR /builder
ARG TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX}"
ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}" ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}"
ARG APP_UID="${APP_UID:-6972}" ARG APP_UID="${APP_UID:-6972}"
ARG APP_GID="${APP_GID:-6972}" ARG APP_GID="${APP_GID:-6972}"
@ -13,7 +14,8 @@ WORKDIR /home/app/
RUN git clone https://github.com/oobabooga/text-generation-webui.git RUN git clone https://github.com/oobabooga/text-generation-webui.git
WORKDIR /home/app/text-generation-webui WORKDIR /home/app/text-generation-webui
RUN GPU_CHOICE=D LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose RUN GPU_CHOICE=D LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} COPY /user_data/CMD_FLAGS.txt /home/app/text-generation-webui/user_data
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005}
# set umask to ensure group read / write at runtime # set umask to ensure group read / write at runtime
WORKDIR /home/app/text-generation-webui WORKDIR /home/app/text-generation-webui
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh

View file

@ -4,6 +4,22 @@ services:
build: build:
context: . context: .
args: args:
# Requirements file to use:
# | GPU | requirements file to use |
# |--------|---------|
# | NVIDIA | `requirements.txt` |
# | AMD | `requirements_amd.txt` |
# | CPU only | `requirements_cpu_only.txt` |
# | Apple Intel | `requirements_apple_intel.txt` |
# | Apple Silicon | `requirements_apple_silicon.txt` |
# Default: `requirements.txt`
# BUILD_REQUIREMENTS: requirements.txt
# Extension requirements to build:
# BUILD_EXTENSIONS:
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-} BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
APP_GID: ${APP_GID:-6972} APP_GID: ${APP_GID:-6972}
APP_UID: ${APP_UID:-6972} APP_UID: ${APP_UID:-6972}

View file

@ -14,7 +14,8 @@ WORKDIR /home/app/
RUN git clone https://github.com/oobabooga/text-generation-webui.git RUN git clone https://github.com/oobabooga/text-generation-webui.git
WORKDIR /home/app/text-generation-webui WORKDIR /home/app/text-generation-webui
RUN GPU_CHOICE=A LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose RUN GPU_CHOICE=A LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} COPY /user_data/CMD_FLAGS.txt /home/app/text-generation-webui/user_data
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005}
WORKDIR /home/app/text-generation-webui WORKDIR /home/app/text-generation-webui
# set umask to ensure group read / write at runtime # set umask to ensure group read / write at runtime
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen

View file

@ -4,8 +4,22 @@ services:
build: build:
context: . context: .
args: args:
# Requirements file to use:
# | GPU | requirements file to use |
# |--------|---------|
# | NVIDIA | `requirements.txt` |
# | AMD | `requirements_amd.txt` |
# | CPU only | `requirements_cpu_only.txt` |
# | Apple Intel | `requirements_apple_intel.txt` |
# | Apple Silicon | `requirements_apple_silicon.txt` |
# Default: `requirements.txt`
# BUILD_REQUIREMENTS: requirements.txt
# Extension requirements to build:
# BUILD_EXTENSIONS:
# specify which CUDA compute capability your card supports: https://developer.nvidia.com/cuda-gpus # specify which CUDA compute capability your card supports: https://developer.nvidia.com/cuda-gpus
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-8.6;8.9+PTX} TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-} BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
APP_GID: ${APP_GID:-6972} APP_GID: ${APP_GID:-6972}
APP_UID: ${APP_UID:-6972} APP_UID: ${APP_UID:-6972}

View file

@ -2,44 +2,31 @@ Used to have multi-turn conversations with the model.
## Input area ## Input area
The main action buttons are: The following buttons can be found. Note that the hover menu can be replaced with always-visible buttons with the `--chat-buttons` flag.
* **Send**: sends your message and makes the model start a reply. * **Generate**: sends your message and makes the model start a reply.
* **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model). * **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model).
The hover menu (☰) that appears over the chat area contains:
* **Regenerate**: similar to Send, but your last message is used as input instead of the text in the input field. Note that if the temperature/top_p/top_k parameters are low in the "Parameters" tab of the UI, the new reply may end up identical to the previous one.
* **Continue**: makes the model attempt to continue the existing reply. In some cases, the model may simply end the existing turn immediately without generating anything new, but in other cases, it may generate a longer reply. * **Continue**: makes the model attempt to continue the existing reply. In some cases, the model may simply end the existing turn immediately without generating anything new, but in other cases, it may generate a longer reply.
* **Regenerate**: similar to Generate, but your last message is used as input instead of the text in the input field. Note that if the temperature/top_p/top_k parameters are low in the "Parameters" tab of the UI, the new reply may end up identical to the previous one.
* **Remove last reply**: removes the last input/output pair from the history and sends your last message back into the input field. * **Remove last reply**: removes the last input/output pair from the history and sends your last message back into the input field.
* **Replace last reply**: replaces the last reply with whatever you typed into the input field. Useful in conjunction with "Copy last reply" if you want to edit the bot response.
* **Copy last reply**: sends the contents of the bot's last reply to the input field.
* **Impersonate**: makes the model generate a new message on your behalf in the input field, taking into consideration the existing chat history. * **Impersonate**: makes the model generate a new message on your behalf in the input field, taking into consideration the existing chat history.
* **Send dummy message**: adds a new message to the chat history without causing the model to generate a reply. * **Send dummy message**: adds a new message to the chat history without causing the model to generate a reply.
* **Send dummy reply**: adds a new reply to the chat history as if the model had generated this reply. Useful in conjunction with "Send dummy message". * **Send dummy reply**: adds a new reply to the chat history as if the model had generated this reply. Useful in conjunction with "Send dummy message".
* **Send to Notebook**: sends the entire chat prompt up to now to the Notebook tab. * **Start new chat**: starts a new conversation while keeping the old one saved. If you are talking to a character that has a "Greeting" message defined, this message will be automatically added to the new history.
* **Show controls**: checkbox that toggles the visibility of the sidebar controls (Start reply with, Mode, Chat style, etc.). Shortcut: Ctrl+S. * **Send to default**: sends the entire chat prompt up to now to the "Default" tab.
* **Send to notebook**: sends the entire chat prompt up to now to the "Notebook" tab.
The **Show controls** checkbox causes the input fields below the input textbox to disappear. It is useful for making the page fit entirely into view and not scroll.
## Past chats ## Past chats
Allows you to switch between the current and previous conversations with the current character, or between the current and previous instruct conversations (if in "instruct" mode). The available buttons are: Allows you to switch between the current and previous conversations with the current character, or between the current and previous instruct conversations (if in "instruct" mode). The **Rename** menu can be used to give a unique name to the selected conversation, and the 🗑️ button allows you to delete it.
* **Branch**: creates a branch of the current conversation at a specific message. ## Start reply with
* **Rename**: allows you to give a unique name to the selected conversation.
* **🗑️**: deletes the selected conversation.
* **New chat**: starts a new conversation. If you are talking to a character that has a "Greeting" message defined, this message will be automatically added to the new history.
A search field is also available to filter conversations by name. Whatever you type there will appear at the start of every reply by the bot. This is useful to guide the response in the desired direction.
## Sidebar controls
The sidebar (toggled via "Show controls") contains:
* **Start reply with**: whatever you type there will appear at the start of every reply by the bot. This is useful to guide the response in the desired direction.
* **Reasoning effort**: controls the thinking depth for models that support it. Options: low, medium, high.
* **Enable thinking**: enables extended thinking mode for models that support it.
* **Activate web search**: when enabled, the model can search the web for information before replying. You can also set the number of pages to download.
* **Mode**: see below.
* **Chat style**: see below.
* **Command for chat-instruct mode**: the command that is used in chat-instruct mode to query the model to generate a reply on behalf of the character. Can be used creatively to generate specific kinds of responses. Inside this string, `<|character|>` is a placeholder that gets replaced with the bot name, and `<|prompt|>` is a placeholder that gets replaced with the full chat prompt.
## Mode ## Mode
@ -86,7 +73,7 @@ Now that an instruction-following model is defined, we can move on to describing
### Chat ### Chat
Used for talking to the character defined under "Character" tab using a simple chat prompt in this format: Used for talking to the character defined under "Parameters" > "Character" using a simple chat prompt in this format:
``` ```
Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology. Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.
@ -96,7 +83,7 @@ You: How are you?
Chiharu Yamada: I'm doing well, thank you for asking! Is there something specific you would like to talk about or ask me? I'm here to help answer any questions you may have. Chiharu Yamada: I'm doing well, thank you for asking! Is there something specific you would like to talk about or ask me? I'm here to help answer any questions you may have.
``` ```
There are 3 adjustable parameters in the "Character" tab being used in this prompt: There are 3 adjustable parameters in "Parameters" > "Character" being used in this prompt:
* The **Context** string appears at the top of the prompt. Most often it describes the bot's personality and adds a few example messages to guide the model towards the desired reply length and format. This string never gets truncated: as the prompt size increases, old messages get removed one at a time until the prompt becomes smaller than the truncation length set under "Parameters" > "Generation" > "Truncate the prompt up to this length". * The **Context** string appears at the top of the prompt. Most often it describes the bot's personality and adds a few example messages to guide the model towards the desired reply length and format. This string never gets truncated: as the prompt size increases, old messages get removed one at a time until the prompt becomes smaller than the truncation length set under "Parameters" > "Generation" > "Truncate the prompt up to this length".
* The **Your name** string appears at the beginning of each user reply. By default, this string is "You". * The **Your name** string appears at the beginning of each user reply. By default, this string is "You".
@ -112,7 +99,7 @@ Used for talking to an instruction-following model using the prompt format defin
The prompt format is defined by the **Instruction template** parameter in "Parameters" > "Instruction template", which represents a Jinja2 template. The prompt format is defined by the **Instruction template** parameter in "Parameters" > "Instruction template", which represents a Jinja2 template.
Note that when you load a model in the "Model" tab, the web UI will try to automatically detect its instruction template (if any), and will update the values under "Parameters" > "Instruction template" accordingly. This is done using a set of regular expressions defined in `user_data/models/config.yaml`. This detection is not guaranteed to be accurate. You should check the model card on Hugging Face to see if you are using the correct prompt format. Note that when you load a model in the "Model" tab, the web UI will try to automatically detect its instruction template (if any), and will update the values under "Parameters" > "Instruction template" accordingly. This is done using a set of regular expressions defined in `models/config.yaml`. This detection is not guaranteed to be accurate. You should check the model card on Hugging Face to see if you are using the correct prompt format.
### Chat-instruct ### Chat-instruct
@ -140,6 +127,8 @@ Here, the command is
Below this command, the regular chat prompt is added, including its Context string and the chat history, and then the user turn ends. The bot turn starts with the "Character's name" string followed by `:`, thus prompting the instruction-following model to write a single reply for the character. Below this command, the regular chat prompt is added, including its Context string and the chat history, and then the user turn ends. The bot turn starts with the "Character's name" string followed by `:`, thus prompting the instruction-following model to write a single reply for the character.
The chat-instruct command can be customized under "Parameters" > "Instruction template" > "Command for chat-instruct mode". Inside that command string, `<|character|>` is a placeholder that gets replaced with the bot name, and `<|prompt|>` is a placeholder that gets replaced with the full chat prompt.
Note that you can get creative: instead of writing something trivial like "Write a single reply for the character", you could add more complex instructions like Note that you can get creative: instead of writing something trivial like "Write a single reply for the character", you could add more complex instructions like
> This is an adventure game, and your task is to write a reply in name of "<|character|>" where 3 options are given for the user to then choose from. > This is an adventure game, and your task is to write a reply in name of "<|character|>" where 3 options are given for the user to then choose from.
@ -156,4 +145,4 @@ The styles are only applied to chat and chat-instruct modes. Instruct mode has i
## Character gallery ## Character gallery
This menu is a built-in extension defined under `text-generation-webui/extensions/gallery`. It displays a gallery with your characters, and if you click on a character, it will be automatically selected in the Character tab. This menu is a built-in extension defined under `text-generation-webui/extensions/gallery`. It displays a gallery with your characters, and if you click on a character, it will be automatically selected in the menu under "Parameters" > "Character".

View file

@ -10,11 +10,11 @@ The number on the lower right of the Input box counts the number of tokens in th
Below the Input box, the following buttons can be found: Below the Input box, the following buttons can be found:
* **Continue**: starts a new generation taking as input the text in the "Output" box.
* **Generate**: starts a new generation. * **Generate**: starts a new generation.
* **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model). * **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model).
* **Continue**: starts a new generation taking as input the text in the "Output" box.
In the **Prompt** menu, you can select from saved prompts stored in `user_data/logs/notebook`. The **New** button creates a new prompt, the **Rename** button renames the selected prompt, and the 🗑️ button deletes it. The 🔄 button refreshes the list. In the **Prompt** menu, you can select from some predefined prompts defined under `text-generation-webui/prompts`. The 💾 button saves your current input as a new prompt, the 🗑️ button deletes the selected prompt, and the 🔄 button refreshes the list. If you come up with an interesting prompt for a certain task, you are welcome to submit it to the repository.
### Output ### Output

View file

@ -43,15 +43,9 @@ For more information about the parameters, the [transformers documentation](http
* **presence_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition. Previously called "additive_repetition_penalty". * **presence_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition. Previously called "additive_repetition_penalty".
* **frequency_penalty**: Repetition penalty that scales based on how many times the token has appeared in the context. Be careful with this; there's no limit to how much a token can be penalized. * **frequency_penalty**: Repetition penalty that scales based on how many times the token has appeared in the context. Be careful with this; there's no limit to how much a token can be penalized.
* **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used. * **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
* **dry_multiplier**: Set to greater than 0 to enable DRY (Don't Repeat Yourself) sampling. It penalizes tokens that would extend a sequence that already appeared in the context. Recommended value: 0.8.
* **dry_allowed_length**: The longest sequence that can be repeated without being penalized by DRY. Shorter values make DRY more aggressive.
* **dry_base**: Controls how fast the DRY penalty grows with increasing sequence length.
* **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text. * **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
* **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens. * **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens.
* **top_a**: Tokens with probability smaller than `(top_a) * (probability of the most likely token)^2` are discarded. * **top_a**: Tokens with probability smaller than `(top_a) * (probability of the most likely token)^2` are discarded.
* **top_n_sigma**: Keeps only tokens within N standard deviations of the mean log-probability. Acts as an adaptive cutoff that adjusts to the shape of the distribution. 0 disables it.
* **xtc_threshold**: eXclusion from Top Choices (XTC) sampling. If 2 or more tokens have probability above this threshold, the top token may be removed. This encourages the model to use less common word choices and can increase creativity.
* **xtc_probability**: The probability that XTC removal will actually happen when the threshold condition is met. Set to 1 for it to always apply, or lower for occasional application.
* **epsilon_cutoff**: In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled. * **epsilon_cutoff**: In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled.
* **eta_cutoff**: In units of 1e-4; a reasonable value is 3. The main parameter of the special Eta Sampling technique. See [this paper](https://arxiv.org/pdf/2210.15191.pdf) for a description. * **eta_cutoff**: In units of 1e-4; a reasonable value is 3. The main parameter of the special Eta Sampling technique. See [this paper](https://arxiv.org/pdf/2210.15191.pdf) for a description.
* **guidance_scale**: The main parameter for Classifier-Free Guidance (CFG). [The paper](https://arxiv.org/pdf/2306.17806.pdf) suggests that 1.5 is a good value. It can be used in conjunction with a negative prompt or not. * **guidance_scale**: The main parameter for Classifier-Free Guidance (CFG). [The paper](https://arxiv.org/pdf/2306.17806.pdf) suggests that 1.5 is a good value. It can be used in conjunction with a negative prompt or not.
@ -61,62 +55,36 @@ For more information about the parameters, the [transformers documentation](http
*Note: Use either mirostat or dynamic_temperature, not both at the same time.* *Note: Use either mirostat or dynamic_temperature, not both at the same time.*
* **mirostat_tau**: Target perplexity for Mirostat sampling. Controls how “surprising” the text is. Higher values = more diverse, lower = more predictable. Preset Arena suggests 8 as a good value. * **mirostat_tau**: Target perplexity for Mirostat sampling. Controls how “surprising” the text is. Higher values = more diverse, lower = more predictable. Preset Arena suggests 8 as a good value.
* **mirostat_eta**: Learning rate for Mirostat's perplexity adjustment. Higher values = adapts faster but less stable, lower values = slower but more stable. Preset Arena suggests 0.1 as a good value. * **mirostat_eta**: Learning rate for Mirostat's perplexity adjustment. Higher values = adapts faster but less stable, lower values = slower but more stable. Preset Arena suggests 0.1 as a good value.
* **adaptive_target**: Target probability for adaptive-p sampling. This method adjusts the sampling threshold dynamically based on an exponential moving average of recent token probabilities. 0 disables it.
* **adaptive_decay**: EMA decay rate for adaptive-p sampling. Controls how quickly the running average adjusts. Default: 0.9.
* **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynatemp_low" (minimum) and "dynatemp_high" (maximum), with an entropy-based scaling. The steepness of the curve is controlled by "dynatemp_exponent". * **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynatemp_low" (minimum) and "dynatemp_high" (maximum), with an entropy-based scaling. The steepness of the curve is controlled by "dynatemp_exponent".
*Note: Use either dynamic_temperature or mirostat, not both at the same time.* *Note: Use either dynamic_temperature or mirostat, not both at the same time.*
* **smoothing_factor**: Activates Quadratic Sampling. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked. * **smoothing_factor**: Activates Quadratic Sampling. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked.
* **smoothing_curve**: Adjusts the dropoff curve of Quadratic Sampling. Higher values make the curve steeper. Only takes effect when smoothing_factor is set.
* **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency. Note: this parameter takes precedence over "Sampler priority". That means that `temperature`/`dynamic_temperature`/`quadratic_sampling` will be removed from wherever they are and moved to the end of the stack. * **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency. Note: this parameter takes precedence over "Sampler priority". That means that `temperature`/`dynamic_temperature`/`quadratic_sampling` will be removed from wherever they are and moved to the end of the stack.
* **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked). * **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
* **Seed**: Set the PyTorch seed to this number. Note that some loaders do not use PyTorch (notably llama.cpp). For these loaders, the seed has no effect. * **Seed**: Set the PyTorch seed to this number. Note that some loaders do not use PyTorch (notably llama.cpp), and others are not deterministic (ExLlamaV2). For these loaders, the seed has no effect.
* **encoder_repetition_penalty**: Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge. * **encoder_repetition_penalty**: Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.
* **no_repeat_ngram_size**: If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases. * **no_repeat_ngram_size**: If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.
To the right (or below if you are on mobile), the following parameters are present: To the right (or below if you are on mobile), the following parameters are present:
* **Truncate the prompt up to this length**: Used to prevent the prompt from getting bigger than the model's context length. In the case of the transformers loader, which allocates memory dynamically, this parameter can also be used to set a VRAM ceiling and prevent out-of-memory errors. This parameter is automatically updated with the model's context length (from "ctx_size" for loaders that use this parameter, and from the model metadata directly for loaders that do not) when you load a model. * **Truncate the prompt up to this length**: Used to prevent the prompt from getting bigger than the model's context length. In the case of the transformers loader, which allocates memory dynamically, this parameter can also be used to set a VRAM ceiling and prevent out-of-memory errors. This parameter is automatically updated with the model's context length (from "n_ctx" or "max_seq_len" for loaders that use these parameters, and from the model metadata directly for loaders that do not) when you load a model.
* **Maximum number of tokens/second**: to make text readable in real-time in case the model is generating too fast. Good if you want to flex and tell everyone how good your GPU is. * **Maximum number of tokens/second**: to make text readable in real-time in case the model is generating too fast. Good if you want to flex and tell everyone how good your GPU is.
* **Custom system message**: If not empty, will be used instead of the default system message in the instruction template. Useful for customizing the personality of the chatbot. Example: "You are a duck."
* **Custom stopping strings**: The model stops generating as soon as any of the strings set in this field is generated. Note that when generating text in the Chat tab, some default stopping strings are set regardless of this parameter, like "\nYour Name:" and "\nBot name:" for chat mode. That's why this parameter has a "Custom" in its name. * **Custom stopping strings**: The model stops generating as soon as any of the strings set in this field is generated. Note that when generating text in the Chat tab, some default stopping strings are set regardless of this parameter, like "\nYour Name:" and "\nBot name:" for chat mode. That's why this parameter has a "Custom" in its name.
* **Custom token bans**: Allows you to ban the model from generating certain tokens altogether. You need to find the token IDs under "Default" > "Tokens" or "Notebook" > "Tokens", or by looking at the `tokenizer.json` for the model directly. * **Custom token bans**: Allows you to ban the model from generating certain tokens altogether. You need to find the token IDs under "Default" > "Tokens" or "Notebook" > "Tokens", or by looking at the `tokenizer.json` for the model directly.
* **auto_max_new_tokens**: When checked, the max_new_tokens parameter is expanded in the backend to the available context length. The maximum length is given by the "truncation_length" parameter. This is useful for getting long replies in the Chat tab without having to click on "Continue" many times. * **auto_max_new_tokens**: When checked, the max_new_tokens parameter is expanded in the backend to the available context length. The maximum length is given by the "truncation_length" parameter. This is useful for getting long replies in the Chat tab without having to click on "Continue" many times.
* **Ban the eos_token**: One of the possible tokens that a model can generate is the EOS (End of Sequence) token. When it is generated, the generation stops prematurely. When this parameter is checked, that token is banned from being generated, and the generation will always generate "max_new_tokens" tokens. * **Ban the eos_token**: One of the possible tokens that a model can generate is the EOS (End of Sequence) token. When it is generated, the generation stops prematurely. When this parameter is checked, that token is banned from being generated, and the generation will always generate "max_new_tokens" tokens.
* **Add the bos_token to the beginning of prompts**: By default, the tokenizer will add a BOS (Beginning of Sequence) token to your prompt. During training, BOS tokens are used to separate different documents. If unchecked, no BOS token will be added, and the model will interpret your prompt as being in the middle of a document instead of at the start of one. This significantly changes the output and can make it more creative. * **Add the bos_token to the beginning of prompts**: By default, the tokenizer will add a BOS (Beginning of Sequence) token to your prompt. During training, BOS tokens are used to separate different documents. If unchecked, no BOS token will be added, and the model will interpret your prompt as being in the middle of a document instead of at the start of one. This significantly changes the output and can make it more creative.
* **Skip special tokens**: When decoding the generated tokens, skip special tokens from being converted to their text representation. Otherwise, BOS appears as `<s>`, EOS as `</s>`, etc. * **Skip special tokens**: When decoding the generated tokens, skip special tokens from being converted to their text representation. Otherwise, BOS appears as `<s>`, EOS as `</s>`, etc.
* **prompt_lookup_num_tokens**: Activates Prompt Lookup Decoding, a form of speculative decoding for the Transformers loader. It guesses future tokens by looking for matching patterns in the prompt itself, which can speed up generation for tasks that involve repeating or paraphrasing parts of the input.
* **Activate text streaming**: When unchecked, the full response is outputted at once, without streaming the words one at a time. I recommend unchecking this parameter on high latency networks like running the webui on Google Colab or using `--share`. * **Activate text streaming**: When unchecked, the full response is outputted at once, without streaming the words one at a time. I recommend unchecking this parameter on high latency networks like running the webui on Google Colab or using `--share`.
* **Static KV cache**: Use a static cache for improved performance with the Transformers loader. May not be compatible with all models.
* **Sampler priority**: Allows you to customize the order in which the different samplers are applied. The first sampler on the list gets applied first. With this, custom orders like `top_p -> temperature -> top_k` can be defined. * **Sampler priority**: Allows you to customize the order in which the different samplers are applied. The first sampler on the list gets applied first. With this, custom orders like `top_p -> temperature -> top_k` can be defined.
* **DRY sequence breakers**: Tokens across which DRY sequence matching is not continued. Typically punctuation and special tokens. Only used when DRY is active (dry_multiplier > 0). * **Load grammar from file**: Loads a GBNF grammar from a file under `text-generation-webui/grammars`. The output is written to the "Grammar" box below. You can also save and delete custom grammars using this menu.
* **Load grammar from file**: Loads a GBNF grammar from a file under `user_data/grammars`. The output is written to the "Grammar" box below. You can also save and delete custom grammars using this menu.
* **Grammar**: Allows you to constrain the model output to a particular format. For instance, you can make the model generate lists, JSON, specific words, etc. Grammar is extremely powerful and I highly recommend it. The syntax looks a bit daunting at first sight, but it gets very easy once you understand it. See the [GBNF Guide](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) for details. * **Grammar**: Allows you to constrain the model output to a particular format. For instance, you can make the model generate lists, JSON, specific words, etc. Grammar is extremely powerful and I highly recommend it. The syntax looks a bit daunting at first sight, but it gets very easy once you understand it. See the [GBNF Guide](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) for details.
### Chat tab controls ## Character
The following parameters appear in the Chat tab sidebar rather than the Parameters tab: Parameters that define the character that is used in the Chat tab when "chat" or "chat-instruct" are selected under "Mode".
* **reasoning_effort**: Controls the thinking depth for models that support it (used by GPT-OSS). Options: low, medium, high. * **Character**: A dropdown menu where you can select from saved characters, save a new character (💾 button), and delete the selected character (🗑️).
* **enable_thinking**: Enables extended thinking mode for models that support it (used by Seed-OSS and pre-2507 Qwen3). When enabled, the model can use a thinking step before generating its reply. * **Your name**: Your name as it appears in the prompt.
## Instruction template
This sub-tab within the Parameters tab defines the instruction template used in the Chat tab when "instruct" or "chat-instruct" are selected under "Mode".
* **Saved instruction templates**: A dropdown menu where you can select a template. Click **Load** to apply it. The 💾 button saves the current template, and the 🗑️ button deletes the selected one.
* **Instruction template**: A Jinja2 template that defines the prompt format for the instruction-following conversation.
* **Send to notebook**: Send the full instruction template in string format to the Notebook tab.
* **Chat template**: A Jinja2 template that defines the prompt format for regular chat conversations with characters.
## Character tab
The Character tab is a separate top-level tab that contains the following sub-tabs:
### Character
Parameters that define the character used in the Chat tab when "chat" or "chat-instruct" are selected under "Mode".
* **Character**: A dropdown menu where you can select from saved characters, save a new character (💾 button), and delete the selected character (🗑️). The **Restore character** button resets the character to its last saved state.
* **Character's name**: The bot name as it appears in the prompt. * **Character's name**: The bot name as it appears in the prompt.
* **Context**: A string that is always at the top of the prompt. It never gets truncated. It usually defines the bot's personality and some key elements of the conversation. * **Context**: A string that is always at the top of the prompt. It never gets truncated. It usually defines the bot's personality and some key elements of the conversation.
* **Greeting**: An opening message for the bot. When set, it appears whenever you start a new chat. * **Greeting**: An opening message for the bot. When set, it appears whenever you start a new chat.
@ -130,26 +98,31 @@ Note: the following replacements take place in the context and greeting fields w
So you can use those special placeholders in your character definitions. They are commonly found in TavernAI character cards. So you can use those special placeholders in your character definitions. They are commonly found in TavernAI character cards.
### User ## Instruction template
Allows you to create and manage user profiles. Defines the instruction template that is used in the Chat tab when "instruct" or "chat-instruct" are selected under "Mode".
* **User**: A dropdown to select, save (💾), or delete (🗑️) user profiles. * **Saved instruction templates**: A dropdown menu where you can load a saved template, save a new template (💾 button), and delete the currently selected template (🗑️).
* **Name**: Your name as it appears in the prompt. * **Custom system message**: A message that defines the personality of the chatbot, replacing its default "System message" string. Example: "You are a duck."
* **Description**: An optional description of yourself that can be referenced in conversations. * **Instruction template**: A Jinja2 template that defines the prompt format for the instruction-following conversation.
* **Send to default**: Send the full instruction template in string format to the Default tab.
* **Send to notebook**: Send the full instruction template in string format to the Notebook tab.
* **Send to negative prompt**: Send the full instruction template in string format to the "Negative prompt" field under "Parameters" > "Generation".
* **Chat template**: A Jinja2 template that defines the prompt format for regular chat conversations with characters.
* **Command for chat-instruct mode**: The command that is used in chat-instruct mode to query the model to generate a reply on behalf of the character. Can be used creatively to generate specific kinds of responses.
### Chat history ## Chat history
In this tab, you can download the current chat history in JSON format and upload a previously saved chat history. In this tab, you can download the current chat history in JSON format and upload a previously saved chat history.
When a history is uploaded, a new chat is created to hold it. That is, you don't lose your current chat in the Chat tab. When a history is uploaded, a new chat is created to hold it. That is, you don't lose your current chat in the Chat tab.
### Upload character ## Upload character
#### YAML or JSON ### YAML or JSON
Allows you to upload characters in the YAML format used by the web UI, including optionally a profile picture. Allows you to upload characters in the YAML format used by the web UI, including optionally a profile picture.
#### TavernAI PNG ### TavernAI PNG
Allows you to upload a TavernAI character card. It will be converted to the internal YAML format of the web UI after upload. Allows you to upload a TavernAI character card. It will be converted to the internal YAML format of the web UI after upload.

View file

@ -2,89 +2,112 @@ This is where you load models, apply LoRAs to a loaded model, and download new m
## Model loaders ## Model loaders
### Transformers
Loads: full precision (16-bit or 32-bit) models. The repository usually has a clean name without GGUF, EXL2, GPTQ, or AWQ in its name, and the model files are named `pytorch_model.bin` or `model.safetensors`.
Example: [https://huggingface.co/lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5).
Full precision models use a ton of VRAM, so you will usually want to select the "load_in_4bit" and "use_double_quant" options to load the model in 4-bit precision using bitsandbytes.
This loader can also load GPTQ models and train LoRAs with them. For that, make sure to check the "auto-devices" and "disable_exllama" options before loading the model.
Options:
* **gpu-memory**: When set to greater than 0, activates CPU offloading using the accelerate library, where part of the layers go to the CPU. The performance is very bad. Note that accelerate doesn't treat this parameter very literally, so if you want the VRAM usage to be at most 10 GiB, you may need to set this parameter to 9 GiB or 8 GiB. It can be used in conjunction with "load_in_8bit" but not with "load-in-4bit" as far as I'm aware.
* **cpu-memory**: Similarly to the parameter above, you can also set a limit on the amount of CPU memory used. Whatever doesn't fit either in the GPU or the CPU will go to a disk cache, so to use this option you should also check the "disk" checkbox.
* **compute_dtype**: Used when "load-in-4bit" is checked. I recommend leaving the default value.
* **quant_type**: Used when "load-in-4bit" is checked. I recommend leaving the default value.
* **alpha_value**: Used to extend the context length of a model with a minor loss in quality. I have measured 1.75 to be optimal for 1.5x context, and 2.5 for 2x context. That is, with alpha = 2.5 you can make a model with 4096 context length go to 8192 context length.
* **rope_freq_base**: Originally another way to write "alpha_value", it ended up becoming a necessary parameter for some models like CodeLlama, which was fine-tuned with this set to 1000000 and hence needs to be loaded with it set to 1000000 as well.
* **compress_pos_emb**: The first and original context-length extension method, discovered by [kaiokendev](https://kaiokendev.github.io/til). When set to 2, the context length is doubled, 3 and it's tripled, etc. It should only be used for models that have been fine-tuned with this parameter set to different than 1. For models that have not been tuned to have greater context length, alpha_value will lead to a smaller accuracy loss.
* **cpu**: Loads the model in CPU mode using Pytorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with transformers is older than llama.cpp and it works, but it's a lot slower. Note: this parameter has a different interpretation in the llama.cpp loader (see below).
* **load-in-8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training and not inference, so load-in-8bit is slower than load-in-4bit (but more accurate).
* **bf16**: Use bfloat16 precision instead of float16 (the default). Only applies when quantization is not used.
* **auto-devices**: When checked, the backend will try to guess a reasonable value for "gpu-memory" to allow you to load a model with CPU offloading. I recommend just setting "gpu-memory" manually instead. This parameter is also needed for loading GPTQ models, in which case it needs to be checked before loading the model.
* **disk**: Enable disk offloading for layers that don't fit into the GPU and CPU combined.
* **load-in-4bit**: Load the model in 4-bit precision using bitsandbytes.
* **trust-remote-code**: Some models use custom Python code to load the model or the tokenizer. For such models, this option needs to be set. It doesn't download any remote content: all it does is execute the .py files that get downloaded with the model. Those files can potentially include malicious code; I have never seen it happen, but it is in principle possible.
* **no_use_fast**: Do not use the "fast" version of the tokenizer. Can usually be ignored; only check this if you can't load the tokenizer for your model otherwise.
* **use_flash_attention_2**: Set use_flash_attention_2=True while loading the model. Possibly useful for training.
* **disable_exllama**: Only applies when you are loading a GPTQ model through the transformers loader. It needs to be checked if you intend to train LoRAs with the model.
### ExLlamav2_HF
Loads: GPTQ and EXL2 models. EXL2 models usually have "EXL2" in the model name, while GPTQ models usually have GPTQ in the model name, or alternatively something like "-4bit-128g" in the name.
Examples:
* https://huggingface.co/turboderp/Llama2-70B-exl2
* https://huggingface.co/TheBloke/Llama-2-13B-chat-GPTQ
* **gpu-split**: If you have multiple GPUs, the amount of memory to allocate per GPU should be set in this field. Make sure to set a lower value for the first GPU, as that's where the cache is allocated.
* **max_seq_len**: The maximum sequence length for the model. In ExLlamaV2, the cache is preallocated, so the higher this value, the higher the VRAM. It is automatically set to the maximum sequence length for the model based on its metadata, but you may need to lower this value to be able to fit the model into your GPU. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "max_seq_len" so that you don't have to set the same thing twice.
* **cfg-cache**: Creates a second cache to hold the CFG negative prompts. You need to set this if and only if you intend to use CFG in the "Parameters" > "Generation" tab. Checking this parameter doubles the cache VRAM usage.
* **no_flash_attn**: Disables flash attention. Otherwise, it is automatically used as long as the library is installed.
* **cache_8bit**: Creates an 8-bit precision cache instead of a 16-bit one. This saves VRAM but increases perplexity (I don't know by how much).
* **cache_4bit**: Creates a Q4 cache using grouped quantization.
### ExLlamav2
The same as ExLlamav2_HF but using the internal samplers of ExLlamav2 instead of the ones in the Transformers library.
### AutoGPTQ
Loads: GPTQ models.
* **wbits**: For ancient models without proper metadata, sets the model precision in bits manually. Can usually be ignored.
* **groupsize**: For ancient models without proper metadata, sets the model group size manually. Can usually be ignored.
* **triton**: Only available on Linux. Necessary to use models with both act-order and groupsize simultaneously. Note that ExLlamaV2 can load these same models on Windows without triton.
* **no_inject_fused_attention**: Improves performance while increasing the VRAM usage.
* **no_inject_fused_mlp**: Similar to the previous parameter but for Triton only.
* **no_use_cuda_fp16**: On some systems, the performance can be very bad with this unset. Can usually be ignored.
* **desc_act**: For ancient models without proper metadata, sets the model "act-order" parameter manually. Can usually be ignored.
### llama.cpp ### llama.cpp
Loads: GGUF models. Note: GGML models have been deprecated and do not work anymore. Loads: GGUF models. Note: GGML models have been deprecated and do not work anymore.
Example: https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF Example: https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF
* **gpu_layers**: The number of layers to allocate to the GPU. If set to 0, only the CPU will be used. If you want to offload all layers, you can simply set this to the maximum value. * **n-gpu-layers**: The number of layers to allocate to the GPU. If set to 0, only the CPU will be used. If you want to offload all layers, you can simply set this to the maximum value.
* **ctx_size**: Context length of the model. In llama.cpp, the cache is preallocated, so the higher this value, the higher the VRAM. It is automatically set to the maximum sequence length for the model based on the metadata inside the GGUF file, but you may need to lower this value to fit the model into your GPU. Set to 0 for automatic context size based on available memory. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "ctx_size" so that you don't have to set the same thing twice. * **n_ctx**: Context length of the model. In llama.cpp, the cache is preallocated, so the higher this value, the higher the VRAM. It is automatically set to the maximum sequence length for the model based on the metadata inside the GGUF file, but you may need to lower this value to be able to fit the model into your GPU. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "n_ctx" so that you don't have to set the same thing twice.
* **cache_type**: KV cache quantization type. Valid options: `fp16`, `q8_0`, `q4_0`. Lower quantization saves VRAM at the cost of some quality.
* **tensor_split**: For multi-gpu only. Sets the amount of memory to allocate per GPU as proportions. Not to be confused with other loaders where this is set in GB; here you can set something like `30,70` for 30%/70%. * **tensor_split**: For multi-gpu only. Sets the amount of memory to allocate per GPU as proportions. Not to be confused with other loaders where this is set in GB; here you can set something like `30,70` for 30%/70%.
* **batch_size**: Maximum number of prompt tokens to batch together when calling llama_eval. * **n_batch**: Batch size for prompt processing. Higher values are supposed to make generation faster, but I have never obtained any benefit from changing this value.
* **ubatch_size**: Physical maximum batch size for prompt processing.
* **threads**: Number of threads. Recommended value: your number of physical cores. * **threads**: Number of threads. Recommended value: your number of physical cores.
* **threads_batch**: Number of threads for batch processing. Recommended value: your total number of cores (physical + virtual). * **threads_batch**: Number of threads for batch processing. Recommended value: your total number of cores (physical + virtual).
* **cpu_moe**: Force MoE expert layers to run on the CPU, keeping the rest on the GPU. * **tensorcores**: Use llama.cpp compiled with "tensor cores" support, which improves performance on NVIDIA RTX cards in most cases.
* **extra_flags**: Extra flags to pass to llama-server. Format: `flag1=value1,flag2,flag3=value3`. Example: `override-tensor=exps=CPU`. * **streamingllm**: Experimental feature to avoid re-evaluating the entire prompt when part of it is removed, for instance, when you hit the context length for the model in chat mode and an old message is removed.
* **mmproj**: Path to the mmproj file for multimodal (vision) models. This enables image understanding capabilities.
* **streaming_llm**: Experimental feature to avoid re-evaluating the entire prompt when part of it is removed, for instance, when you hit the context length for the model in chat mode and an old message is removed.
* **cpu**: Force a version of llama.cpp compiled without GPU acceleration to be used. Can usually be ignored. Only set this if you want to use CPU only and llama.cpp doesn't work otherwise. * **cpu**: Force a version of llama.cpp compiled without GPU acceleration to be used. Can usually be ignored. Only set this if you want to use CPU only and llama.cpp doesn't work otherwise.
* **row_split**: Split the model by rows across GPUs. This may improve multi-gpu performance. * **no_mul_mat_q**: Disable the mul_mat_q kernel. This kernel usually improves generation speed significantly. This option to disable it is included in case it doesn't work on some system.
* **no_kv_offload**: Do not offload the KV cache to the GPU. This saves VRAM but reduces performance. * **no-mmap**: Loads the model into memory at once, possibly preventing I/O operations later on at the cost of a longer load time.
* **no_mmap**: Loads the model into memory at once, possibly preventing I/O operations later on at the cost of a longer load time. * **mlock**: Force the system to keep the model in RAM rather than swapping or compressing (no idea what this means, never used it).
* **mlock**: Force the system to keep the model in RAM rather than swapping or compressing.
* **numa**: May improve performance on certain multi-cpu systems. * **numa**: May improve performance on certain multi-cpu systems.
### Transformers ### llamacpp_HF
Loads: full precision (16-bit or 32-bit) models, as well as bitsandbytes-quantized models. The repository usually has a clean name without GGUF or EXL3 in its name, and the model files are named `model.safetensors` or split into parts like `model-00001-of-00004.safetensors`. The same as llama.cpp but with transformers samplers, and using the transformers tokenizer instead of the internal llama.cpp tokenizer.
Example: [https://huggingface.co/lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5). To use it, you need to download a tokenizer. There are two options:
Full precision models use a ton of VRAM, so you will usually want to select the "load_in_4bit" and "use_double_quant" options to load the model in 4-bit precision using bitsandbytes. 1) Download `oobabooga/llama-tokenizer` under "Download model or LoRA". That's a default Llama tokenizer.
2) Place your .gguf in a subfolder of `models/` along with these 3 files: `tokenizer.model`, `tokenizer_config.json`, and `special_tokens_map.json`. This takes precedence over Option 1.
Options: It has an additional parameter:
* **gpu_split**: When using multiple GPUs, sets the amount of VRAM in GB to allocate per GPU. Example: `20,7,7`. * **logits_all**: Needs to be checked if you want to evaluate the perplexity of the llama.cpp model using the "Training" > "Perplexity evaluation" tab. Otherwise, leave it unchecked, as it makes prompt processing slower.
* **cpu_memory**: Maximum CPU memory in GiB to use for CPU offloading via the accelerate library. Whatever doesn't fit in the GPU or CPU will go to a disk cache if the "disk" checkbox is enabled.
* **compute_dtype**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
* **quant_type**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
* **attn_implementation**: Choose the attention implementation. Valid options: `sdpa`, `eager`, `flash_attention_2`. The default (`sdpa`) works well in most cases; `flash_attention_2` may be useful for training.
* **cpu**: Loads the model in CPU mode using Pytorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with transformers is older than llama.cpp and it works, but it's a lot slower. Note: this parameter has a different interpretation in the llama.cpp loader (see above).
* **load_in_8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training and not inference, so load_in_8bit is slower than load_in_4bit (but more accurate).
* **bf16**: Use bfloat16 precision instead of float16 (the default). Only applies when quantization is not used.
* **disk**: Enable disk offloading for layers that don't fit into the GPU and CPU combined.
* **load_in_4bit**: Load the model in 4-bit precision using bitsandbytes.
* **use_double_quant**: Use double quantization with 4-bit loading for reduced memory usage.
* **trust-remote-code**: Some models use custom Python code to load the model or the tokenizer. For such models, this option needs to be set. It doesn't download any remote content: all it does is execute the .py files that get downloaded with the model. Those files can potentially include malicious code; I have never seen it happen, but it is in principle possible.
* **no_use_fast**: Do not use the "fast" version of the tokenizer. Can usually be ignored; only check this if you can't load the tokenizer for your model otherwise.
### ExLlamav3_HF ### AutoAWQ
Loads: EXL3 models. These models usually have "EXL3" or "exl3" in the model name. Loads: AWQ models.
Uses the ExLlamaV3 backend with Transformers samplers. Example: https://huggingface.co/TheBloke/Phind-CodeLlama-34B-v2-AWQ
* **ctx_size**: Context length of the model. The cache is preallocated, so the higher this value, the higher the VRAM. It is automatically set to the maximum sequence length for the model based on its metadata, but you may need to lower this value to fit the model into your GPU. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "ctx_size" so that you don't have to set the same thing twice. The parameters are overall similar to AutoGPTQ.
* **cache_type**: KV cache quantization type. Valid options: `fp16`, `q2` to `q8`. You can also specify key and value bits separately, e.g. `q4_q8`. Lower quantization saves VRAM at the cost of some quality.
* **gpu_split**: Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: `20,7,7`.
* **cfg_cache**: Creates a second cache to hold the CFG negative prompts. You need to set this if and only if you intend to use CFG in the "Parameters" > "Generation" tab. Checking this parameter doubles the cache VRAM usage.
* **no_use_fast**: Do not use the "fast" version of the tokenizer.
* **enable_tp**: Enable Tensor Parallelism (TP) to split the model across GPUs.
* **tp_backend**: The backend for tensor parallelism. Valid options: `native`, `nccl`. Default: `native`.
### ExLlamav3
The same as ExLlamav3_HF but using the internal samplers of ExLlamaV3 instead of the ones in the Transformers library. Supports speculative decoding with a draft model. Also supports multimodal (vision) models natively.
* **ctx_size**: Same as ExLlamav3_HF.
* **cache_type**: Same as ExLlamav3_HF.
* **gpu_split**: Same as ExLlamav3_HF.
* **enable_tp**: Enable Tensor Parallelism (TP) to split the model across GPUs.
* **tp_backend**: The backend for tensor parallelism. Valid options: `native`, `nccl`. Default: `native`.
### TensorRT-LLM
Loads: TensorRT-LLM engine models. These are highly optimized models compiled specifically for NVIDIA GPUs.
* **ctx_size**: Context length of the model.
* **cpp_runner**: Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn't support streaming yet.
## Model dropdown ## Model dropdown
Here you can select a model to be loaded, refresh the list of available models, load/unload/reload the selected model, and save the settings for the model. The "settings" are the values in the input fields (checkboxes, sliders, dropdowns) below this dropdown. Here you can select a model to be loaded, refresh the list of available models (🔄), load/unload/reload the selected model, and save the settings for the model. The "settings" are the values in the input fields (checkboxes, sliders, dropdowns) below this dropdown.
After saving, those settings will get restored whenever you select that model again in the dropdown menu. After saving, those settings will get restored whenever you select that model again in the dropdown menu.
@ -92,14 +115,14 @@ If the **Autoload the model** checkbox is selected, the model will be loaded as
## LoRA dropdown ## LoRA dropdown
Used to apply LoRAs to the model. Note that LoRA support is not implemented for all loaders. Check the [What Works](https://github.com/oobabooga/text-generation-webui/wiki/What-Works) page for details. Used to apply LoRAs to the model. Note that LoRA support is not implemented for all loaders. Check this [page](https://github.com/oobabooga/text-generation-webui/wiki) for details.
## Download model or LoRA ## Download model or LoRA
Here you can download a model or LoRA directly from the https://huggingface.co/ website. Here you can download a model or LoRA directly from the https://huggingface.co/ website.
* Models will be saved to `user_data/models`. * Models will be saved to `text-generation-webui/models`.
* LoRAs will be saved to `user_data/loras`. * LoRAs will be saved to `text-generation-webui/loras`.
In the input field, you can enter either the Hugging Face username/model path (like `facebook/galactica-125m`) or the full model URL (like `https://huggingface.co/facebook/galactica-125m`). To specify a branch, add it at the end after a ":" character like this: `facebook/galactica-125m:main`. In the input field, you can enter either the Hugging Face username/model path (like `facebook/galactica-125m`) or the full model URL (like `https://huggingface.co/facebook/galactica-125m`). To specify a branch, add it at the end after a ":" character like this: `facebook/galactica-125m:main`.

View file

@ -1,121 +1,139 @@
## Training Your Own LoRAs ## Training Your Own LoRAs
A LoRA is tied to a specific model architecture — a LoRA trained on Llama 3 8B won't work on Mistral 7B. Train on the exact model you plan to use. The WebUI seeks to make training your own LoRAs as easy as possible. It comes down to just a few simple steps:
### Quick Start ### **Step 1**: Make a plan.
- What base model do you want to use? The LoRA you make has to be matched up to a single architecture (e.g. LLaMA-13B) and cannot be transferred to others (e.g. LLaMA-7B, StableLM, etc. would all be different). Derivatives of the same model (e.g. Alpaca finetune of LLaMA-13B) might be transferable, but even then it's best to train exactly on what you plan to use.
- What are you training it on? Do you want it to learn real information, a simple format, ...?
1. Load your base model with the **Transformers** loader (no LoRAs loaded). ### **Step 2**: Gather a dataset.
2. Open the **Training** tab > **Train LoRA**. - If you use a dataset similar to the [Alpaca](https://github.com/gururise/AlpacaDataCleaned/blob/main/alpaca_data_cleaned.json) format, that is natively supported by the `Formatted Dataset` input in the WebUI, with premade formatter options.
3. Pick a dataset and configure parameters (see [below](#parameters)). - If you use a dataset that isn't matched to Alpaca's format, but uses the same basic JSON structure, you can make your own format file by copying `training/formats/alpaca-format.json` to a new file and [editing its content](#format-files).
4. Click **Start LoRA Training** and monitor the [loss](#loss). - If you can get the dataset into a simple text file, that works too! You can train using the `Raw text file` input option.
5. When done, load the LoRA from the **Models** tab and test it. - This means you can for example just copy/paste a chatlog/documentation page/whatever you want, shove it in a plain text file, and train on it.
- If you use a structured dataset not in this format, you may have to find an external way to convert it - or open an issue to request native support.
### Resuming Training ### **Step 3**: Do the training.
- **3.1**: Load the WebUI, and your model.
- Make sure you don't have any LoRAs already loaded (unless you want to train for multi-LoRA usage).
- **3.2**: Open the `Training` tab at the top, `Train LoRA` sub-tab.
- **3.3**: Fill in the name of the LoRA, select your dataset in the dataset options.
- **3.4**: Select other parameters to your preference. See [parameters below](#parameters).
- **3.5**: click `Start LoRA Training`, and wait.
- It can take a few hours for a large dataset, or just a few minutes if doing a small run.
- You may want to monitor your [loss value](#loss) while it goes.
To resume from a checkpoint, use the same LoRA name and uncheck `Override Existing Files`. If checkpoints exist (from `Save every n steps`), training will automatically resume from the latest one with full optimizer and scheduler state preserved. Note that you cannot change the `Rank` of an already created LoRA. ### **Step 4**: Evaluate your results.
- Load the LoRA under the Models Tab.
- You can go test-drive it on the `Text generation` tab, or you can use the `Perplexity evaluation` sub-tab of the `Training` tab.
- If you used the `Save every n steps` option, you can grab prior copies of the model from sub-folders within the LoRA model's folder and try them instead.
You should also use `Copy parameters from` to restore the UI settings (learning rate, epochs, etc.) from the previous run, so that training continues with the same configuration. ### **Step 5**: Re-run if you're unhappy.
- Make sure to unload the LoRA before training it.
- You can simply resume a prior run - use `Copy parameters from` to select your LoRA, and edit parameters. Note that you cannot change the `Rank` of an already created LoRA.
- If you want to resume from a checkpoint saved along the way, simply copy the contents of the checkpoint folder into the LoRA's folder.
- (Note: `adapter_model.bin` is the important file that holds the actual LoRA content).
- This will reset Learning Rate and Steps back to the start. If you want to resume as if you were midway through, you can adjust your Learning Rate to the last reported LR in logs and reduce your epochs.
- Or, you can start over entirely if you prefer.
- If your model is producing corrupted outputs, you probably need to start over and use a lower Learning Rate.
- If your model isn't learning detailed information but you want it to, you might need to just run more epochs, or you might need a higher Rank.
- If your model is enforcing a format you didn't want, you may need to tweak your dataset, or start over and not train as far.
### Troubleshooting ## Format Files
- **Corrupted outputs**: Start over with a lower Learning Rate. If using JSON formatted datasets, they are presumed to be in the following approximate format:
- **Not learning enough**: Run more epochs, or increase the Rank.
- **Unwanted formatting**: Tweak your dataset, or train for fewer steps.
## Instruction Templates
All instruction/chat training uses `apply_chat_template()` with Jinja2 templates. You have two options in the **Instruction Template** dropdown:
- **Chat Template**: Uses the model's built-in chat template from its tokenizer. Works with instruct/chat models that ship with a chat template (Llama 3, Qwen, Mistral, etc.).
- **Named template** (e.g. ChatML, Alpaca, Llama-v3, etc.): Loads a Jinja2 template from `user_data/instruction-templates/`. This is useful for base models that don't have a built-in template, or when you want to override the model's default template.
Both options are functionally identical — the only difference is where the Jinja2 template string comes from. In both cases:
- The dataset is tokenized via `apply_chat_template()`
- Labels are automatically masked so only assistant responses are trained on
- Multi-turn conversations are supported natively
- Special tokens are handled correctly by the template
The WebUI ships with 50+ templates in `user_data/instruction-templates/`. You can also add your own by creating a `.yaml` file with an `instruction_template` key containing a Jinja2 template string, or a plain `.jinja` file.
**Dataset formats:** Your JSON dataset can use either of these structures:
OpenAI messages format:
```json ```json
[ [
{ {
"messages": [ "somekey": "somevalue",
{"role": "system", "content": "You are a helpful assistant."}, "key2": "value2"
{"role": "user", "content": "What is Python?"}, },
{"role": "assistant", "content": "A programming language."}, {
{"role": "user", "content": "What's it used for?"}, // etc
{"role": "assistant", "content": "Web dev, data science, scripting, and more."}
]
} }
] ]
``` ```
ShareGPT format (`conversations` key with `from`/`value` fields): Where the keys (eg `somekey`, `key2` above) are standardized, and relatively consistent across the dataset, and the values (eg `somevalue`, `value2`) contain the content actually intended to be trained.
```json
[
{
"conversations": [
{"from": "system", "value": "You are a helpful assistant."},
{"from": "human", "value": "What is Python?"},
{"from": "gpt", "value": "A programming language."},
{"from": "human", "value": "What's it used for?"},
{"from": "gpt", "value": "Web dev, data science, scripting, and more."}
]
}
]
```
## Text Dataset For Alpaca, the keys are `instruction`, `input`, and `output`, wherein `input` is sometimes blank.
For pretraining-style training on raw text, use the **Text Dataset** tab. Your dataset should be a JSON file with one document per row, each with a `"text"` key: A simple format file for Alpaca to be used as a chat bot is:
```json ```json
[ {
{"text": "First document content..."}, "instruction,output": "User: %instruction%\nAssistant: %output%",
{"text": "Second document content..."} "instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
] }
``` ```
This is the standard format used by most pretraining datasets (The Pile, RedPajama, etc.). Note that the keys (eg `instruction,output`) are a comma-separated list of dataset keys, and the values are a simple string that use those keys with `%%`.
Each document is tokenized (with BOS token), concatenated into one long token sequence, and split into chunks of `Cutoff Length` tokens. The final chunk is padded if shorter than the cutoff length. When `Add EOS token` is enabled, an EOS token is appended after each document before concatenation, helping the model learn document boundaries. So for example if a dataset has `"instruction": "answer my question"`, then the format file's `User: %instruction%\n` will be automatically filled in as `User: answer my question\n`.
- `Stride Length` controls the overlap between consecutive chunks in tokens. Set to 0 for non-overlapping chunks (the standard concatenate-and-split approach). Values like 256 or 512 create overlapping chunks that help the model learn context across chunk boundaries, at the cost of more training samples. If you have different sets of key inputs, you can make your own format file to match it. This format-file is designed to be as simple as possible to enable easy editing to match your needs.
## Target Modules ## Raw Text File Settings
By default, **Target all linear layers** is enabled. This uses peft's `all-linear` mode, which applies LoRA to every `nn.Linear` layer in the model except the output head (`lm_head`). It works for any model architecture. When using raw text files as your dataset, the text is automatically split into chunks based on your `Cutoff Length`, and you get a few basic options to configure them.
- `Overlap Length` is how much to overlap chunks by. Overlapping chunks helps prevent the model from learning strange mid-sentence cuts, and instead learn continual sentences that flow from earlier text.
If you uncheck it, you can manually select individual projection modules (`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `down_proj`, `up_proj`). Targeting fewer modules reduces VRAM usage and adapter size, but also reduces how much the model can learn. The default selection of `q_proj` + `v_proj` is the minimum for basic style/format training. - `Prefer Newline Cut Length` sets a maximum distance in characters to shift the chunk cut towards newlines. Doing this helps prevent lines from starting or ending mid-sentence, preventing the model from learning to cut off sentences randomly.
- `Hard Cut String` sets a string that indicates there must be a hard cut without overlap. This defaults to `\n\n\n`, meaning 3 newlines. No trained chunk will ever contain this string. This allows you to insert unrelated sections of text in the same text file, but still ensure the model won't be taught to randomly change the subject.
## Parameters ## Parameters
Each parameter has a description in the UI. Below is guidance on the most important choices. The basic purpose and function of each parameter is documented on-page in the WebUI, so read through them in the UI to understand your options.
That said, here's a guide to the most important parameter choices you should consider:
### VRAM ### VRAM
VRAM usage during training is roughly similar to inference with ~1000 tokens of context. If you can run the model, you can probably train LoRAs with the default settings. If you run out of VRAM, reduce `Micro Batch Size` or `Cutoff Length`. Training 4-bit quantized models uses more VRAM — set `Micro Batch Size` to `1` to compensate. - First, you must consider your VRAM availability.
- Generally, VRAM usage for training with default parameters is very close to that of generating text (with 1000+ tokens of context); i.e., if you can generate text, you can train LoRAs.
- Note: worse by default in the 4-bit monkeypatch currently. Reduce `Micro Batch Size` to `1` to restore this to expectations.
- If you have VRAM to spare, setting higher batch sizes will use more VRAM and get you better quality training in exchange.
- If you have large data, setting a higher cutoff length may be beneficial, but will cost significant VRAM. If you can spare some, set your batch size to `1` and see how high you can push your cutoff length.
- If you're low on VRAM, reducing batch size or cutoff length will of course improve that.
- Don't be afraid to just try it and see what happens. If it's too much, it will just error out, and you can lower settings and try again.
### Rank ### Rank
Higher rank = more learning capacity = larger adapter = more VRAM. Use 4–8 for style/format, 128–256 to teach factual knowledge. - Second, you want to consider the amount of learning you want.
- For example, you may wish to just learn a dialogue format (as in the case of Alpaca) in which case setting a low `Rank` value (32 or lower) works great.
- Or, you might be training on project documentation you want the bot to understand and be able to answer questions about, in which case the higher the rank, the better.
- Generally, higher Rank = more precise learning = more total content learned = more VRAM usage while training.
### Learning Rate and Epochs ### Learning Rate and Epochs
These control how aggressively the model learns and how many times it sees the data. Higher LR + fewer epochs = fast but rough. Lower LR + more epochs = slower but higher quality. The scheduler (default: cosine) decays the LR over the course of training — see [HuggingFace docs](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules#schedules) for graphs of each option. - Third, how carefully you want it to be learned.
- In other words, how okay or not you are with the model losing unrelated understandings.
- You can control this with 3 key settings: the Learning Rate, its scheduler, and your total epochs.
- The learning rate controls how much change is made to the model by each token it sees.
- It's in scientific notation normally, so for example `3e-4` means `3 * 10^-4` which is `0.0003`. The number after `e-` controls how many `0`s are in the number.
- Higher values let training run faster, but also are more likely to corrupt prior data in the model.
- You essentially have two variables to balance: the LR, and Epochs.
- If you make LR higher, you can set Epochs equally lower to match. High LR + low epochs = very fast, low quality training.
- If you make LR low, set epochs high. Low LR + high epochs = slow but high-quality training.
- The scheduler controls change-over-time as you train - it starts high, and then goes low. This helps balance getting data in, and having decent quality, at the same time.
- You can see graphs of the different scheduler options [in the HuggingFace docs here](https://moon-ci-docs.huggingface.co/docs/transformers/pr_1/en/main_classes/optimizer_schedules#transformers.SchedulerType)
## Loss ## Loss
When you're running training, the WebUI's console window will log reports that include, among other things, a numeric value named `Loss`. It will start as a high number, and gradually get lower and lower as it goes. When you're running training, the WebUI's console window will log reports that include, among other things, a numeric value named `Loss`. It will start as a high number, and gradually get lower and lower as it goes.
Loss measures how far the model's predictions are from the training data, with `0` meaning a perfect match. It's calculated as the cross-entropy between the model's output distribution and the expected tokens. "Loss" in the world of AI training theoretically means "how close is the model to perfect", with `0` meaning "absolutely perfect". This is calculated by measuring the difference between the model outputting exactly the text you're training it to output, and what it actually outputs.
In practice, a loss of `0` means the model has overfit — it memorized the training data at the expense of its general capabilities. In practice, a good LLM should have a very complex variable range of ideas running in its artificial head, so a loss of `0` would indicate that the model has broken and forgotten how to think about anything other than what you trained it on.
Loss is a balancing game: you want it low enough that the model learns your data, but not so low that it loses general knowledge. Generally, if it goes below `1.0`, overfitting is likely and you should stop training. In some cases you may want to go as low as `0.5` (if you need very predictable outputs). Different goals have different needs, so experiment and see what works best for you. So, in effect, Loss is a balancing game: you want to get it low enough that it understands your data, but high enough that it isn't forgetting everything else. Generally, if it goes below `1.0`, it's going to start forgetting its prior memories, and you should stop training. In some cases you may prefer to take it as low as `0.5` (if you want it to be very very predictable). Different goals have different needs, so don't be afraid to experiment and see what works best for you.
Note: if you see Loss start at or suddenly jump to exactly `0`, it is likely something has gone wrong in your training process (eg model corruption). Note: if you see Loss start at or suddenly jump to exactly `0`, it is likely something has gone wrong in your training process (eg model corruption).
## Note: 4-Bit Monkeypatch
The [4-bit LoRA monkeypatch](GPTQ-models-(4-bit-mode).md#using-loras-in-4-bit-mode) works for training, but has side effects:
- VRAM usage is higher currently. You can reduce the `Micro Batch Size` to `1` to compensate.
- Models do funky things. LoRAs apply themselves, or refuse to apply, or spontaneously error out, or etc. It can be helpful to reload base model or restart the WebUI between training/usage to minimize chances of anything going haywire.
- Loading or working with multiple LoRAs at the same time doesn't currently work.
- Generally, recognize and treat the monkeypatch as the dirty temporary hack it is - it works, but isn't very stable. It will get better in time when everything is merged upstream for full official support.

View file

@ -1,15 +1,6 @@
Here you can restart the UI with new settings. Here you can restart the UI with new settings.
## Settings * **Available extensions**: shows a list of extensions available under `text-generation-webui/extensions`.
* **Toggle light/dark theme**: switches between light and dark mode.
* **Show two columns in the Notebook tab**: toggles between the two-column Default layout and the single-column Notebook layout.
* **Turn long pasted text into attachments in the Chat tab**: when enabled, long pasted text is automatically converted into file attachments.
* **Include attachments/search results from previous messages in the chat prompt**: when enabled, attachments and web search results from earlier messages are included in subsequent prompts.
## Extensions & flags
* **Available extensions**: shows a list of extensions available under `text-generation-webui/extensions` and `text-generation-webui/user_data/extensions`. Note that some of these extensions may require manually installing Python requirements through the command: `pip install -r extensions/extension_name/requirements.txt`.
* **Boolean command-line flags**: shows command-line flags of bool (true/false) type. * **Boolean command-line flags**: shows command-line flags of bool (true/false) type.
After selecting your desired flags and extensions, you can restart the UI by clicking on **Apply flags/extensions and restart**. After selecting your desired flags and extensions, you can restart the UI by clicking on **Apply flags/extensions and restart**.
@ -36,6 +27,6 @@ If you used the one-click installer, this command should be executed in the term
## Saving UI defaults ## Saving UI defaults
The **Save extensions settings to user_data/settings.yaml** button gathers the visible values in the UI and saves them to `user_data/settings.yaml` so that your settings will persist across multiple restarts of the UI. The **Save UI defaults to settings.yaml** button gathers the visible values in the UI and saves them to settings.yaml so that your settings will persist across multiple restarts of the UI.
Note that preset parameters like temperature are not individually saved, so you need to first save your preset and select it in the preset menu before saving the defaults. Note that preset parameters like temperature are not individually saved, so you need to first save your preset and select it in the preset menu before saving the defaults.

View file

@ -21,19 +21,17 @@ If you create an extension, you are welcome to host it in a GitHub repository an
|Extension|Description| |Extension|Description|
|---------|-----------| |---------|-----------|
|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. | |[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. |
|[superboogav2](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superboogav2)| Enhanced RAG extension with support for PDF, DOCX, and PPTX files. | |[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For a detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. |
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. | |[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.|
|[coqui_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/coqui_tts)| Text-to-speech extension using Coqui XTTS v2. |
|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, responses are replaced with an audio widget. | |[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, responses are replaced with an audio widget. |
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. | |[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. |
|[perplexity_colors](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/perplexity_colors)| Colors each token in the output text by its associated probability, as derived from the model logits. |
|[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.|
|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. |
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). | |[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). |
|[long_replies](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/long_replies)| Forces longer replies by suppressing early newlines in the model output. |
|[ngrok](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/ngrok)| Allows you to access the web UI remotely using the ngrok reverse tunnel service (free). It's an alternative to the built-in Gradio `--share` feature. |
|[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. |
|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that adds a hidden string at the beginning of the bot's reply in chat mode. | |[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that adds a hidden string at the beginning of the bot's reply in chat mode. |
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. |
|[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. |
|[ngrok](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/ngrok)| Allows you to access the web UI remotely using the ngrok reverse tunnel service (free). It's an alternative to the built-in Gradio `--share` feature. |
|[perplexity_colors](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/perplexity_colors)| Colors each token in the output text by its associated probability, as derived from the model logits. |
## How to write an extension ## How to write an extension
@ -53,8 +51,8 @@ The extensions framework is based on special functions and variables that you ca
| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. | | `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. |
| `def custom_generate_reply(...)` | Overrides the main text generation function. | | `def custom_generate_reply(...)` | Overrides the main text generation function. |
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. | | `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `example` extension for a template. | | `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. |
| `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `example` extension for a template. | | `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `multimodal` extension for an example. |
Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI, and the `is_tab` key is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab. Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI, and the `is_tab` key is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab.
@ -188,7 +186,7 @@ def bot_prefix_modifier(string, state):
def tokenizer_modifier(state, prompt, input_ids, input_embeds): def tokenizer_modifier(state, prompt, input_ids, input_embeds):
""" """
Modifies the input ids and embeds. Modifies the input ids and embeds.
Modifies the input ids and embeds fed to the model. Used by the multimodal extension to put image embeddings in the prompt.
Only used by loaders that use the transformers library for sampling. Only used by loaders that use the transformers library for sampling.
""" """
return prompt, input_ids, input_embeds return prompt, input_ids, input_embeds

View file

@ -13,6 +13,29 @@ Source: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/1126
This file will be automatically detected the next time you start the web UI. This file will be automatically detected the next time you start the web UI.
## DeepSpeed
`DeepSpeed ZeRO-3` is an alternative offloading strategy for full-precision (16-bit) transformers models.
With this, I have been able to load a 6b model (GPT-J 6B) with less than 6GB of VRAM. The speed of text generation is very decent and much better than what would be accomplished with `--auto-devices --gpu-memory 6`.
As far as I know, DeepSpeed is only available for Linux at the moment.
### How to use it
1. Install DeepSpeed:
```
conda install -c conda-forge mpi4py mpich
pip install -U deepspeed
```
2. Start the web UI replacing `python` with `deepspeed --num_gpus=1` and adding the `--deepspeed` flag. Example:
```
deepspeed --num_gpus=1 server.py --deepspeed --chat --model gpt-j-6B
```
## Miscellaneous info ## Miscellaneous info
### You can train LoRAs in CPU mode ### You can train LoRAs in CPU mode

View file

@ -1,52 +1,208 @@
Docker Compose is a way of installing and launching the web UI in an isolated Ubuntu image using only a few commands. Docker Compose is a way of installing and launching the web UI in an isolated Ubuntu image using only a few commands.
## Prerequisites ## Installing Docker Compose
You need Docker Compose v2.17 or higher: In order to create the image as described in the main README, you must have Docker Compose installed (2.17 or higher is recommended):
``` ```
~$ docker compose version ~$ docker compose version
Docker Compose version v2.21.0 Docker Compose version v2.21.0
``` ```
Installation instructions: https://docs.docker.com/engine/install/ The installation instructions for various Linux distributions can be found here:
For NVIDIA GPUs, you also need the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository
## Quick start ## Launching the image
There are four Docker variants available under `docker/`: Use these commands to launch the image:
| Directory | GPU | Notes | ```
|-----------|-----|-------| cd text-generation-webui
| `docker/nvidia` | NVIDIA | Requires NVIDIA Container Toolkit | ln -s docker/{nvidia/Dockerfile,nvidia/docker-compose.yml,.dockerignore} .
| `docker/amd` | AMD | Requires ROCm-compatible GPU | cp docker/.env.example .env
| `docker/intel` | Intel Arc | Beta support | # Edit .env and set TORCH_CUDA_ARCH_LIST based on your GPU model
| `docker/cpu` | None | CPU-only inference |
To launch (using NVIDIA as an example):
```bash
cd text-generation-webui/docker/nvidia
cp ../.env.example .env
# Optionally edit .env to customize ports, TORCH_CUDA_ARCH_LIST, etc.
docker compose up --build docker compose up --build
``` ```
The web UI will be available at `http://localhost:7860`. ## More detailed installation instructions
## User data * [Docker Compose installation instructions](#docker-compose-installation-instructions)
* [Repository with additional Docker files](#dedicated-docker-repository)
Create a `user_data/` directory next to the `docker-compose.yml` to persist your models, characters, presets, and settings between container rebuilds: By [@loeken](https://github.com/loeken).
- [Ubuntu 22.04](#ubuntu-2204)
- [0. youtube video](#0-youtube-video)
- [1. update the drivers](#1-update-the-drivers)
- [2. reboot](#2-reboot)
- [3. install docker](#3-install-docker)
- [4. docker \& container toolkit](#4-docker--container-toolkit)
- [5. clone the repo](#5-clone-the-repo)
- [6. prepare models](#6-prepare-models)
- [7. prepare .env file](#7-prepare-env-file)
- [8. startup docker container](#8-startup-docker-container)
- [Manjaro](#manjaro)
- [update the drivers](#update-the-drivers)
- [reboot](#reboot)
- [docker \& container toolkit](#docker--container-toolkit)
- [continue with ubuntu task](#continue-with-ubuntu-task)
- [Windows](#windows)
- [0. youtube video](#0-youtube-video-1)
- [1. choco package manager](#1-choco-package-manager)
- [2. install drivers/dependencies](#2-install-driversdependencies)
- [3. install wsl](#3-install-wsl)
- [4. reboot](#4-reboot)
- [5. git clone \&\& startup](#5-git-clone--startup)
- [6. prepare models](#6-prepare-models-1)
- [7. startup](#7-startup)
- [notes](#notes)
### Ubuntu 22.04
#### 0. youtube video
A video walking you through the setup can be found here:
[![oobabooga text-generation-webui setup in docker on ubuntu 22.04](https://img.youtube.com/vi/ELkKWYh8qOk/0.jpg)](https://www.youtube.com/watch?v=ELkKWYh8qOk)
#### 1. update the drivers
In the “software updater”, update the drivers to the latest version of the proprietary driver.
#### 2. reboot
Reboot to switch to the new driver.
#### 3. install docker
```bash ```bash
mkdir -p user_data sudo apt update
sudo apt-get install curl
sudo mkdir -m 0755 -p /etc/apt/keyrings
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
echo \
"deb [arch="$(dpkg --print-architecture)" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | \
sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
sudo apt update
sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin docker-compose -y
sudo usermod -aG docker $USER
newgrp docker
``` ```
This directory is mounted into the container at runtime. You can place a `CMD_FLAGS.txt` inside it to pass persistent flags to the web UI (e.g., `--api`). #### 4. docker & container toolkit
```bash
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://nvidia.github.io/libnvidia-container/stable/ubuntu22.04/amd64 /" | \
sudo tee /etc/apt/sources.list.d/nvidia.list > /dev/null
sudo apt update
sudo apt install nvidia-docker2 nvidia-container-runtime -y
sudo systemctl restart docker
```
Models can be downloaded through the web UI's “Model” tab once it's running, and they will be saved to `user_data/models/`. #### 5. clone the repo
```
git clone https://github.com/oobabooga/text-generation-webui
cd text-generation-webui
```
#### 6. prepare models
download and place the models inside the models folder. tested with:
4bit
https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483891617
https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483941105
8bit:
https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1484235789
#### 7. prepare .env file
edit .env values to your needs.
```bash
cp .env.example .env
nano .env
```
#### 8. startup docker container
```bash
docker compose up --build
```
### Manjaro
manjaro/arch is similar to ubuntu just the dependency installation is more convenient
#### update the drivers
```bash
sudo mhwd -a pci nonfree 0300
```
#### reboot
```bash
reboot
```
#### docker & container toolkit
```bash
yay -S docker docker-compose buildkit gcc nvidia-docker
sudo usermod -aG docker $USER
newgrp docker
sudo systemctl restart docker # required by nvidia-container-runtime
```
#### continue with ubuntu task
continue at [5. clone the repo](#5-clone-the-repo)
### Windows
#### 0. youtube video
A video walking you through the setup can be found here:
[![oobabooga text-generation-webui setup in docker on windows 11](https://img.youtube.com/vi/ejH4w5b5kFQ/0.jpg)](https://www.youtube.com/watch?v=ejH4w5b5kFQ)
#### 1. choco package manager
install package manager (https://chocolatey.org/ )
```
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
```
#### 2. install drivers/dependencies
```
choco install nvidia-display-driver cuda git docker-desktop
```
#### 3. install wsl
wsl --install
#### 4. reboot
after reboot enter username/password in wsl
#### 5. git clone && startup
clone the repo and edit .env values to your needs.
```
cd Desktop
git clone https://github.com/oobabooga/text-generation-webui
cd text-generation-webui
COPY .env.example .env
notepad .env
```
#### 6. prepare models
download and place the models inside the models folder. tested with:
4bit https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483891617 https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483941105
8bit: https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1484235789
#### 7. startup
```
docker compose up
```
### notes
on older ubuntus you can manually install the docker compose plugin like this:
```
DOCKER_CONFIG=${DOCKER_CONFIG:-$HOME/.docker}
mkdir -p $DOCKER_CONFIG/cli-plugins
curl -SL https://github.com/docker/compose/releases/download/v2.17.2/docker-compose-linux-x86_64 -o $DOCKER_CONFIG/cli-plugins/docker-compose
chmod +x $DOCKER_CONFIG/cli-plugins/docker-compose
export PATH="$HOME/.docker/cli-plugins:$PATH"
```
## Dedicated docker repository ## Dedicated docker repository
An external repository maintains a docker wrapper for this project as well as several pre-configured 'one-click' `docker compose` variants. It can be found at: [Atinoda/text-generation-webui-docker](https://github.com/Atinoda/text-generation-webui-docker). An external repository maintains a docker wrapper for this project as well as several pre-configured 'one-click' `docker compose` variants (e.g., updated branches of GPTQ). It can be found at: [Atinoda/text-generation-webui-docker](https://github.com/Atinoda/text-generation-webui-docker).

View file

@ -1,25 +1,13 @@
## Using an AMD GPU in Linux ## Using an AMD GPU in Linux
Requires ROCm 6.4 to be installed. Requires ROCm SDK 5.4.2 or 5.4.3 to be installed. Some systems may also
need:
### Option 1: One-click installer
The one-click installer (`start_linux.sh`) automatically detects AMD GPUs. When prompted, select the AMD option, or set the `GPU_CHOICE` environment variable before running:
``` ```
GPU_CHOICE=B ./start_linux.sh sudo apt-get install libstdc++-12-dev
``` ```
### Option 2: Manual conda install Edit the "one_click.py" script using a text editor and un-comment and
modify the lines near the top of the script according to your setup. In
Follow the manual conda installation instructions in the README, using the AMD PyTorch command: particular, modify the `os.environ["ROCM_PATH"] = '/opt/rocm'` line to
point to your ROCm installation.
```
pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/rocm6.4
```
Then install the project requirements with the AMD requirements file:
```
pip install -r requirements/full/requirements_amd.txt
```

View file

@ -39,7 +39,7 @@ curl http://127.0.0.1:5000/v1/completions \
#### Chat completions #### Chat completions
Works best with instruction-following models. If the "instruction_template" variable is not provided, it will be guessed automatically based on the model name using the regex patterns in `user_data/models/config.yaml`. Works best with instruction-following models. If the "instruction_template" variable is not provided, it will be guessed automatically based on the model name using the regex patterns in `models/config.yaml`.
```shell ```shell
curl http://127.0.0.1:5000/v1/chat/completions \ curl http://127.0.0.1:5000/v1/chat/completions \
@ -338,35 +338,6 @@ for event in client.events():
print() print()
``` ```
#### Python parallel requests example
The API supports handling multiple requests in parallel. For ExLlamaV3, this works out of the box. For llama.cpp, you need to pass `--parallel N` to set the number of concurrent slots.
```python
import concurrent.futures
import requests
url = "http://127.0.0.1:5000/v1/chat/completions"
prompts = [
"Write a haiku about the ocean.",
"Explain quantum computing in simple terms.",
"Tell me a joke about programmers.",
]
def send_request(prompt):
response = requests.post(url, json={
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 200,
})
return response.json()["choices"][0]["message"]["content"]
with concurrent.futures.ThreadPoolExecutor() as executor:
results = list(executor.map(send_request, prompts))
for prompt, result in zip(prompts, results):
print(f"Q: {prompt}\nA: {result}\n")
```
#### Python example with API key #### Python example with API key
Replace Replace
@ -388,93 +359,83 @@ headers = {
in any of the examples above. in any of the examples above.
#### Tool/Function calling #### Tool/Function Calling Example
Use a model with tool calling support (Qwen, Mistral, GPT-OSS, etc). Tools are passed via the `tools` parameter and the prompt is automatically formatted using the model's Jinja2 template. You need to use a model with tools support. The prompt will be automatically formatted using the model's Jinja2 template.
When the model decides to call a tool, the response will have `finish_reason: "tool_calls"` and a `tool_calls` array with structured function names and arguments. You then execute the tool, send the result back as a `role: "tool"` message, and continue until the model responds with `finish_reason: "stop"`. Request:
Some models call multiple tools in parallel (Qwen, Mistral), while others call one at a time (GPT-OSS). The loop below handles both styles. ```
curl http://127.0.0.1:5000/v1/chat/completions \
```python -H "Content-Type: application/json" \
import json -d '{
import requests "messages": [
{
url = "http://127.0.0.1:5000/v1/chat/completions" "role": "system",
"content": "You are a helpful assistant."
# Define your tools },
tools = [ {
"role": "user",
"content": "What time is it currently in New York City?"
}
],
"tools": [
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": "get_weather", "name": "get_current_time",
"description": "Get the current weather for a given location", "description": "Get current time in a specific timezones",
"parameters": { "parameters": {
"type": "object", "type": "object",
"required": ["timezone"],
"properties": { "properties": {
"location": {"type": "string", "description": "City name"}, "timezone": {
}, "type": "string",
"required": ["location"] "description": "IANA timezone name (e.g., America/New_York, Europe/London). Use Europe/Berlin as local timezone if no timezone provided by the user."
} }
} }
}, }
}
}
]
}'
```
Sample response:
```
{
"id": "chatcmpl-1746532051477984256",
"object": "chat.completion",
"created": 1746532051,
"model": "qwen2.5-coder-14b-instruct-q4_k_m.gguf",
"choices": [
{ {
"type": "function", "index": 0,
"function": { "finish_reason": "tool_calls",
"name": "get_time", "message": {
"description": "Get the current time in a given timezone",
"parameters": {
"type": "object",
"properties": {
"timezone": {"type": "string", "description": "IANA timezone string"},
},
"required": ["timezone"]
}
}
},
]
def execute_tool(name, arguments):
"""Replace this with your actual tool implementations."""
if name == "get_weather":
return {"temperature": 22, "condition": "sunny", "humidity": 45}
elif name == "get_time":
return {"time": "2:30 PM", "timezone": "JST"}
return {"error": f"Unknown tool: {name}"}
messages = [{"role": "user", "content": "What time is it in Tokyo and what's the weather like there?"}]
# Tool-calling loop: keep going until the model gives a final answer
for _ in range(10):
response = requests.post(url, json={"messages": messages, "tools": tools}).json()
choice = response["choices"][0]
if choice["finish_reason"] == "tool_calls":
# Add the assistant's response (with tool_calls) to history
messages.append({
"role": "assistant", "role": "assistant",
"content": choice["message"]["content"], "content": "```xml\n<function>\n{\n \"name\": \"get_current_time\",\n \"arguments\": {\n \"timezone\": \"America/New_York\"\n }\n}\n</function>\n```"
"tool_calls": choice["message"]["tool_calls"], },
}) "tool_calls": [
{
# Execute each tool and add results to history "type": "function",
for tool_call in choice["message"]["tool_calls"]: "function": {
name = tool_call["function"]["name"] "name": "get_current_time",
arguments = json.loads(tool_call["function"]["arguments"]) "arguments": "{\"timezone\": \"America/New_York\"}"
result = execute_tool(name, arguments) },
"id": "call_52ij07mh",
print(f"Tool call: {name}({arguments}) => {result}") "index": "0"
messages.append({ }
"role": "tool", ]
"tool_call_id": tool_call["id"], }
"content": json.dumps(result), ],
}) "usage": {
else: "prompt_tokens": 224,
# Final answer "completion_tokens": 38,
print(f"\nAssistant: {choice['message']['content']}") "total_tokens": 262
break }
}
``` ```
### Environment variables ### Environment variables
@ -515,45 +476,51 @@ OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://127.0.0.1:5000/v1 OPENAI_API_BASE=http://127.0.0.1:5000/v1
``` ```
With the [official python openai client](https://github.com/openai/openai-python) (v1.x), the address can be set like this: With the [official python openai client](https://github.com/openai/openai-python), the address can be set like this:
```python ```python
from openai import OpenAI import openai
client = OpenAI( openai.api_key = "..."
api_key="sk-111111111111111111111111111111111111111111111111", openai.api_base = "http://127.0.0.1:5000/v1"
base_url="http://127.0.0.1:5000/v1" openai.api_version = "2023-05-15"
)
response = client.chat.completions.create(
model="x",
messages=[{"role": "user", "content": "Hello!"}]
)
print(response.choices[0].message.content)
``` ```
With the [official Node.js openai client](https://github.com/openai/openai-node) (v4.x): If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported:
```python
from dotenv import load_dotenv
load_dotenv() # make sure the environment variables are set before import
import openai
```
With the [official Node.js openai client](https://github.com/openai/openai-node) it is slightly more complex because the environment variables are not used by default, so small source code changes may be required to use the environment variables, like so:
```js ```js
import OpenAI from "openai"; const openai = OpenAI(
Configuration({
const client = new OpenAI({
apiKey: process.env.OPENAI_API_KEY, apiKey: process.env.OPENAI_API_KEY,
baseURL: "http://127.0.0.1:5000/v1", basePath: process.env.OPENAI_API_BASE
}); })
);
```
const response = await client.chat.completions.create({ For apps made with the [chatgpt-api Node.js client library](https://github.com/transitive-bullshit/chatgpt-api):
model: "x",
messages: [{ role: "user", content: "Hello!" }], ```js
const api = new ChatGPTAPI({
apiKey: process.env.OPENAI_API_KEY,
apiBaseUrl: process.env.OPENAI_API_BASE
}); });
console.log(response.choices[0].message.content);
``` ```
### Embeddings (alpha) ### Embeddings (alpha)
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings. The model is small and fast. This model and embedding size may change in the future. Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
| model name | dimensions | input max tokens | speed | size | Avg. performance | | model name | dimensions | input max tokens | speed | size | Avg. performance |
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- | | ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
| text-davinci-002 | 768 | 2046 | - | - | - |
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 | | all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 | | all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
@ -561,33 +528,50 @@ In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller st
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable. Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
### Compatibility ### Compatibility & not so compatibility
| API endpoint | notes | Note: the table below may be obsolete.
| ------------------------- | --------------------------------------------------------------------------- |
| /v1/chat/completions | Use with instruction-following models. Supports streaming, tool calls. | | API endpoint | tested with | notes |
| /v1/completions | Text completion endpoint. | | ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- |
| /v1/embeddings | Using SentenceTransformer embeddings. | | /v1/chat/completions | openai.ChatCompletion.create() | Use it with instruction following models |
| /v1/images/generations | Image generation, response_format='b64_json' only. | | /v1/embeddings | openai.Embedding.create() | Using SentenceTransformer embeddings |
| /v1/moderations | Basic support via embeddings. | | /v1/images/generations | openai.Image.create() | Bare bones, no model configuration, response_format='b64_json' only. |
| /v1/models | Lists models. Currently loaded model first. | | /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings |
| /v1/models/{id} | Returns model info. | | /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
| /v1/audio/\* | Supported. | | /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
| /v1/images/edits | Not yet supported. | | /v1/edits | openai.Edit.create() | Removed, use /v1/chat/completions instead |
| /v1/images/variations | Not yet supported. | | /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
| /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint |
| /v1/engines/\*/generate | openai engines.generate | Legacy endpoint |
| /v1/engines | openai engines.list | Legacy Lists models |
| /v1/engines/{model_name} | openai engines.get -i {model_name} | You can use this legacy endpoint to load models via the api or command line |
| /v1/images/edits | openai.Image.create_edit() | not yet supported |
| /v1/images/variations | openai.Image.create_variation() | not yet supported |
| /v1/audio/\* | openai.Audio.\* | supported |
| /v1/files\* | openai.Files.\* | not yet supported |
| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
| /v1/search | openai.search, engines.search | not yet supported |
#### Applications #### Applications
Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variables set, but there are some exceptions. Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variable set, but there are some exceptions.
Note: the table below may be obsolete.
| Compatibility | Application/Library | Website | Notes | | Compatibility | Application/Library | Website | Notes |
| ------------- | -------------------- | ------------------------------------------------------------------------------ | ----------------------------------------------------------------------------------------- | | ------------- | ---------------------- | ------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| ✅❌ | openai-python | https://github.com/openai/openai-python | Use `OpenAI(base_url="http://127.0.0.1:5000/v1")`. Only the endpoints from above work. | | ✅❌ | openai-python (v0.25+) | https://github.com/openai/openai-python | only the endpoints from above are working. OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ✅❌ | openai-node | https://github.com/openai/openai-node | Use `new OpenAI({baseURL: "http://127.0.0.1:5000/v1"})`. See example above. | | ✅❌ | openai-node | https://github.com/openai/openai-node | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
| ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI, Images also work. | | ✅❌ | chatgpt-api | https://github.com/transitive-bullshit/chatgpt-api | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
| ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5000 | | ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI, Images also work |
| ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5000/v1 | | ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5001 |
| ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5000/v1 | | ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ✅ | OpenAI for Notepad++ | https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5000 in the config file, or environment variables. | | ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ✅ | vscode-openai | https://marketplace.visualstudio.com/items?itemName=AndrewButson.vscode-openai | OPENAI_API_BASE=http://127.0.0.1:5000/v1 | | ✅ | OpenAI for Notepad++ | https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5001 in the config file, or environment variables |
| ✅❌ | langchain | https://github.com/hwchase17/langchain | Use `base_url="http://127.0.0.1:5000/v1"`. Results depend on model and prompt formatting. | | ✅ | vscode-openai | https://marketplace.visualstudio.com/items?itemName=AndrewButson.vscode-openai | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |

View file

@ -1,159 +0,0 @@
## Supported models
The following models are supported:
- Qwen 3.5
- GPT-OSS
- Mistral Small / Devstral
- DeepSeek V3
- Kimi-K2
- MiniMax-M2.5
- GLM-5
- Llama 4
Other models that output tool calls as JSON (inside XML tags, code blocks, or plain JSON) are also supported through a generic fallback parser.
## Tool calling in the UI
### 1. Load a model with tool-calling support
Load a model with tool-calling support from the Model tab.
### 2. Select tools
In the chat sidebar, check the tools you want the model to use:
- **web_search** -- Search the web using DuckDuckGo.
- **fetch_webpage** -- Fetch the content of a URL.
- **calculate** -- Evaluate math expressions.
- **get_datetime** -- Get the current date and time.
- **roll_dice** -- Roll dice.
### 3. Chat
Send a message as usual. When the model decides it needs a tool, it will call it automatically. You will see each tool call and its result in a collapsible accordion inside the chat message.
The model may call multiple tools in sequence before giving its final answer.
## Writing custom tools
Each tool is a single `.py` file in `user_data/tools/`. It needs two things:
1. A `tool` dictionary that describes the function (name, description, parameters).
2. An `execute(arguments)` function that runs it and returns the result.
Here is a minimal example (`user_data/tools/get_datetime.py`):
```python
from datetime import datetime
tool = {
"type": "function",
"function": {
"name": "get_datetime",
"description": "Get the current date and time.",
"parameters": {
"type": "object",
"properties": {},
}
}
}
def execute(arguments):
now = datetime.now()
return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")}
```
An example with parameters (`user_data/tools/roll_dice.py`):
```python
import random
tool = {
"type": "function",
"function": {
"name": "roll_dice",
"description": "Roll one or more dice with the specified number of sides.",
"parameters": {
"type": "object",
"properties": {
"count": {"type": "integer", "description": "Number of dice to roll.", "default": 1},
"sides": {"type": "integer", "description": "Number of sides per die.", "default": 20},
},
}
}
}
def execute(arguments):
count = max(1, min(arguments.get("count", 1), 1000))
sides = max(2, min(arguments.get("sides", 20), 1000))
rolls = [random.randint(1, sides) for _ in range(count)]
return {"rolls": rolls, "total": sum(rolls)}
```
You can open the built-in tools in `user_data/tools/` for more examples.
## Tool calling over the API
Tool calling over the API follows the [OpenAI API](https://platform.openai.com/docs/guides/function-calling) convention. Define your tools, send them with your messages, and handle tool calls in a loop until the model gives a final answer.
```python
import json
import requests
url = "http://127.0.0.1:5000/v1/chat/completions"
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather for a given location.",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "City name"},
},
"required": ["location"]
}
}
}
]
def execute_tool(name, arguments):
if name == "get_weather":
return {"temperature": "14°C", "condition": "partly cloudy"}
return {"error": f"Unknown tool: {name}"}
messages = [{"role": "user", "content": "What's the weather like in Paris?"}]
for _ in range(10):
response = requests.post(url, json={"messages": messages, "tools": tools}).json()
choice = response["choices"][0]
if choice["finish_reason"] == "tool_calls":
messages.append({
"role": "assistant",
"content": choice["message"]["content"],
"tool_calls": choice["message"]["tool_calls"],
})
for tool_call in choice["message"]["tool_calls"]:
name = tool_call["function"]["name"]
arguments = json.loads(tool_call["function"]["arguments"])
result = execute_tool(name, arguments)
print(f"Tool call: {name}({arguments}) => {result}")
messages.append({
"role": "tool",
"tool_call_id": tool_call["id"],
"content": json.dumps(result),
})
else:
print(f"\nAssistant: {choice['message']['content']}")
break
```

View file

@ -1,17 +1,20 @@
## What Works ## What Works
| Loader | Loading LoRAs | Training LoRAs | Multimodal | Perplexity evaluation | | Loader | Loading 1 LoRA | Loading 2 or more LoRAs | Training LoRAs | Multimodal extension | Perplexity evaluation |
|----------------|---------------|----------------|------------|-----------------------| |----------------|----------------|-------------------------|----------------|----------------------|-----------------------|
| llama.cpp | ❌ | ❌ | ✅\* | ❌ | | Transformers | ✅ | ✅\*\* | ✅\* | ✅ | ✅ |
| Transformers | ✅ | ✅ | ✅\*\* | ✅ | | llama.cpp | ❌ | ❌ | ❌ | ❌ | use llamacpp_HF |
| ExLlamav3_HF | ❌ | ❌ | ❌ | ✅ | | llamacpp_HF | ❌ | ❌ | ❌ | ❌ | ✅ |
| ExLlamav3 | ❌ | ❌ | ✅ | ❌ | | ExLlamav2_HF | ✅ | ✅ | ❌ | ❌ | ✅ |
| TensorRT-LLM | ❌ | ❌ | ❌ | ❌ | | ExLlamav2 | ✅ | ✅ | ❌ | ❌ | use ExLlamav2_HF |
| AutoGPTQ | ✅ | ❌ | ❌ | ✅ | ✅ |
| AutoAWQ | ? | ❌ | ? | ? | ✅ |
| HQQ | ? | ? | ? | ? | ✅ |
❌ = not supported ❌ = not implemented
✅ = supported ✅ = implemented
\* Via the `mmproj` parameter (multimodal projector file). \* Training LoRAs with GPTQ models also works with the Transformers loader. Make sure to check "auto-devices" and "disable_exllama" before loading the model.
\*\* Via the `send_pictures` extension. \*\* Multi-LoRA in PEFT is tricky and the current implementation does not work reliably in all cases.

View file

@ -24,8 +24,6 @@ from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, RequestException, Timeout from requests.exceptions import ConnectionError, RequestException, Timeout
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
from modules.paths import resolve_user_data_dir
base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co" base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
@ -184,13 +182,11 @@ class ModelDownloader:
is_llamacpp = has_gguf and specific_file is not None is_llamacpp = has_gguf and specific_file is not None
return links, sha256, is_lora, is_llamacpp, file_sizes return links, sha256, is_lora, is_llamacpp, file_sizes
def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None, user_data_dir=None): def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None):
if model_dir: if model_dir:
base_folder = model_dir base_folder = model_dir
else: else:
if user_data_dir is None: base_folder = 'user_data/models' if not is_lora else 'user_data/loras'
user_data_dir = resolve_user_data_dir()
base_folder = str(user_data_dir / 'models') if not is_lora else str(user_data_dir / 'loras')
# If the model is of type GGUF, save directly in the base_folder # If the model is of type GGUF, save directly in the base_folder
if is_llamacpp: if is_llamacpp:
@ -396,8 +392,7 @@ if __name__ == '__main__':
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).') parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.') parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')
parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.') parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (user_data/models).') parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/user_data/models).')
parser.add_argument('--user-data-dir', type=str, default=None, help='Path to the user data directory. Overrides auto-detection.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.') parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.')
@ -413,26 +408,6 @@ if __name__ == '__main__':
sys.exit() sys.exit()
downloader = ModelDownloader(max_retries=args.max_retries) downloader = ModelDownloader(max_retries=args.max_retries)
# Handle direct file URLs (e.g. https://huggingface.co/org/repo/resolve/branch/file.gguf)
if '/resolve/' in model:
url = model if model.startswith('http') else f'{base}/{model}'
url = url.split('?')[0]
filename = url.split('/')[-1]
if args.output:
output_folder = Path(args.output)
elif args.model_dir:
output_folder = Path(args.model_dir)
else:
user_data_dir = Path(args.user_data_dir) if args.user_data_dir else resolve_user_data_dir()
output_folder = user_data_dir / 'models'
output_folder.mkdir(parents=True, exist_ok=True)
print(f"Downloading {filename} to {output_folder}")
downloader.get_single_file(url, output_folder, start_from_scratch=args.clean)
sys.exit()
# Clean up the model/branch names # Clean up the model/branch names
try: try:
model, branch = downloader.sanitize_model_and_branch_names(model, branch) model, branch = downloader.sanitize_model_and_branch_names(model, branch)
@ -446,11 +421,10 @@ if __name__ == '__main__':
) )
# Get the output folder # Get the output folder
user_data_dir = Path(args.user_data_dir) if args.user_data_dir else None
if args.output: if args.output:
output_folder = Path(args.output) output_folder = Path(args.output)
else: else:
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir, user_data_dir=user_data_dir) output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir)
if args.check: if args.check:
# Check previously downloaded files # Check previously downloaded files

View file

@ -0,0 +1,92 @@
# Training_PRO
This is an expanded and reworked Training tab
Maintained by FP
[![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/Q5Q5MOB4M)
Repo home:
https://github.com/FartyPants/Training_PRO
In general the repo above is ahead of the extension included in text WebUi.
## News
- NEFtune: add noise to help with generalization
- Loss Graph in interface.
- Supports Mistral training
- Workarounds for PyTorch and transformers version desync
![image](https://github.com/FartyPants/Training_PRO/assets/23346289/e389ec69-d7ad-4922-9ad9-865625997479)
## Features/Changes
- Chunking: precise raw text slicer (PRTS) uses sentence slicing and making sure things are clean on all ends
- overlap chunking - this special overlapping will make additional overlap block based on logical rules (aka no overlap block on hard cut)
- custom scheduler (follow the code to make your own) In LR Scheduler select FP_low_epoch_annealing - this scheduler will keep the LR constant for first epoch then use cosine for the rest - this part would be best to spawn into a new py file
- saves graph png file at the end with learning rate and loss per epoch
- adding EOS to each block or to hard cut only
- automatically lowers gradient accumulation if you go overboard and set gradient accumulation that will be higher than actual data - transformers would then throw error (or they used to, not sure if still true) but in any way, it will fix bad data
- turn BOS on and OFF
- target selector
- DEMENTOR LEARNING (experimental) Deep Memorization Enforcement Through Overlapping and Repetition. This is an experiment for long-text learning using low epochs (basically use 1 epoch with constant LR or 2 epochs with FP_low_epoch_annealing LR scheduler)
- Getting rid of micro batch size/batch size confusion. Now there is a True Batch Size and a Gradient Accumulation slider, consistent with all the other training out there
- Ability to save Checkpoint during training with a button
- Ability to change Stop Loss during training
- different modes of checkpoint auto saving
- Function to Check Dataset and suggest parameters such as warmup and checkpoint save frequency before training
- Graph Training Loss in interface
- more custom schedulers
### Notes:
This uses its own chunking code for raw text based on sentence splitting. This will avoid weird cuts in the chunks, and each chunk should now start with a sentence and end on a sentence. It works hand in hand with Hard Cut. A proper use is to structure your text into logical blocks (ideas) separated by three \n, then use three \n in hard cut. This way each chunk will contain only one flow of ideas and not derail in the thoughts. The overlapping code will create overlapped blocks on a sentence basis too, but will not cross a hard cut, and thus will not cross different ideas either. Does it make any sense? No? Hmmmm...
### Custom schedulers
A bunch of custom (combination) schedulers are added to the LR schedule. These are based on my own experiments
**FP_low_epoch_annealing**
Uses constant LR (with warmup) for 1 epoch only. The rest of the epoch(s) is cosine annealing. So 10 epochs - 1 will be constant 9 will be nose dive down. However a typical usage would be 2 epochs (hence low epoch in name). 1st is constant, the second is annealing. Simple. I use it 90% of time.
**FP_half_time_annealing**
Like the low epoch, but now the total number of steps is divided by 2. First half is constant, second half is annealing. So 10 epochs - 5 will be constant, 5 will be cosine nose down.
**FP_raise_fall_creative**
This is a sine raise till half of the total steps, then a cosine fall for the rest. (Or you may think of the curve as a sine in its entirety.) Most of the learning is done in the hump, in the middle. The warmup entry has no effect, since the sine rise automatically acts as a warmup.
The idea is to start very mildly as not to overfit with the first blocks of dataset. It seems to broaden the scope of the model making it less strict for tight dataset.
### Targets
Normal LORA is q, v and that's what you should use. You can use (q k v o) or (q k v) and it will give you a lot more trainable parameters. The benefit is that you can keep rank lower and still attain the same coherency as q v with high rank. Guanaco has been trained with QLORA and q k v o for example and they swear by it.
### DEMENTOR LEARNING (experimental) Deep Memorization Enforcement Through Overlapping and Repetition
This is an experimental chunking method to train long-form text in a low number of epochs (basically 1) with sliding repetition. The depth of learning directly depends on the cutoff_length. Increasing the cutoff length will also increase the number of blocks created from long-form text (which is contrary to normal training). It is based on my own wild experiments.
### Getting rid of batch size and micro batch size
Keeping consistency with everyone else.
Listen, There is only ONE batch size - the True batch size (called previously micro-batch size in WebUI) - this is how many blocks are processed at once (during a single step). It eats GPU, but it really helps with the quality training (in fact the ideal batch size would be the same as number of blocks - which is unrealistic) - so the idea is to cram as much True Batch Size before your GPU blows with OOM. On 24GB this is about 10 for 13b (loaded with 4-bit)
So no micro batch size - it is now called True Batch Size, because that's what it is.
The other thing is Gradient Accumulation - this is an emulation of the above Batch size - a virtual batch size, if you will. If your GPU can't handle real batch size then you may fake it using Gradient Accumulation. This will accumulate the gradients over so many steps defined here and then update the weights at the end without increase in GPU.
Gradient accumulation is like a virtual Batch size multiplier without the GPU penalty.
If your batch size is 4 and your gradient accumulation is 2 then it sort of behaves as if we have batch size 8. *Sort of* because Batch size of 4 and GA of 2 is NOT the same as batch size of 2 and GA of 4. (It produces different weights - hence it's not an equivalent). The idea is that if you don't have GPU - using GA to extend batch size is the next best thing (good enough) since you have no other choice.
If all you can afford is 1 batch size, then increasing GA will likely make the learning better in some range of GA (it's not always more is better).
However - GA is not some golden goose. As said, it isn't the same as batch size. In fact GA may worsen your learning as well.
I would suggest a series of experiments where you would put batch size as high as possible without OOM, set GA 1, then repeat training while increasing the GA (2, 4...), and see how the model changes. It's likely that it would follow some sort of curve where GA will seem to help before it will make it worse. Some people believe that if you can squeeze 6 BATCH Size, then you should not bother with GA at all... YMMV
High Batch Size vs High GA would also likely produce different results in terms of learning words vs style. How? Hmmmm... good question.
One optical "benefit" of GA is that the loss will fluctuate less (because of all the gradient accumulation, which works as a form of noise smoothing as well).

View file

@ -0,0 +1,433 @@
from functools import partial
import torch
import transformers
import math
from torch.optim.lr_scheduler import LambdaLR
from peft import (
PeftModel,
)
# ANSI escape sequences for coloured terminal output (names give the colour);
# RESET is the "\033[0m" sequence that ends the colouring.
RED = "\033[91m"
YELLOW = "\033[93m"
GREEN = "\033[92m"
RESET = "\033[0m"
# Phase label most recently printed by a scheduler lambda. The lambdas compare
# against it so a label is printed only when the phase changes.
last_print_label = ''
# Module-level state shared by the custom schedulers. It is written through
# custom_scheduler_global_update() and custom_scheduler_global_setup().
custom_scheduler_params = {'trigger_loss': 0.0, 'ramp_down_ratio':1.0, 'current_loss': 0.0,'dynamic_scheduler_stop': False, 'calc_ramp_down_at_step': 0, 'calc_num_training_steps': 0}
def custom_scheduler_global_update(current_loss: float):
    """Store the latest training loss in the shared ``custom_scheduler_params`` state."""
    custom_scheduler_params['current_loss'] = current_loss
def custom_scheduler_global_setup(trigger_loss: float, ramp_down_ratio: float):
    """Store the trigger settings and reset the derived scheduler state.

    Args:
        trigger_loss: value saved under ``'trigger_loss'``.
        ramp_down_ratio: value saved under ``'ramp_down_ratio'``.
    """
    custom_scheduler_params.update({
        'trigger_loss': trigger_loss,
        'ramp_down_ratio': ramp_down_ratio,
        # total number of steps after the trigger (reset)
        'calc_num_training_steps': 0,
        # step at which the ramp-down trigger occurred (reset)
        'calc_ramp_down_at_step': 0,
        # set once the scheduler reached calc_num_training_steps (reset)
        'dynamic_scheduler_stop': False,
    })
# hold constant to the half of epochs then cosine down to 0
def _get_fp_half_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
    """LR multiplier: linear warmup, hold at 1.0 until half of training, then cosine decay to 0.

    The warmup length is clamped to half of ``num_training_steps``.
    ``num_firstepoch_steps`` is accepted for signature parity and is not used.
    The name of the current phase is printed whenever it differs from the
    last printed one (tracked in the module-level ``last_print_label``).
    """
    global last_print_label

    midpoint = num_training_steps // 2
    warmup_end = min(num_warmup_steps, midpoint)

    if current_step < warmup_end:
        phase = 'Scheduler: Warmup'
        multiplier = float(current_step) / float(max(1, warmup_end))
    elif current_step < midpoint:
        phase = 'Scheduler: Hold'
        multiplier = 1.0
    else:
        phase = 'Scheduler: Annealing'
        cycles = 0.5
        fraction = float(current_step - midpoint) / float(max(1, num_training_steps - midpoint))
        multiplier = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * fraction)))

    if phase != last_print_label:
        print(phase)
    last_print_label = phase

    return multiplier
# raise up in cosine, then fall back in cosine
def _get_fp_cosine_raise_and_fall_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
    """LR multiplier that follows a cosine up to the midpoint of training and back down.

    ``progress`` is measured relative to the midpoint, so it is negative during
    the first half; because cosine is even, the same formula yields the rising
    and the falling part. ``num_warmup_steps`` and ``num_firstepoch_steps`` are
    accepted for signature parity and are not used.
    """
    global last_print_label

    midpoint = num_training_steps // 2

    phase = 'Scheduler: Raise' if current_step < midpoint else 'Scheduler: Fall'
    if phase != last_print_label:
        print(phase)
    last_print_label = phase

    cycles = 0.5
    fraction = float(current_step - midpoint) / float(max(1, num_training_steps - midpoint))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * fraction)))
# constant to the first epochs then cosine down to 0 over the rest epochs
def _get_fp_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
    """LR multiplier: linear warmup, hold at 1.0 through the first epoch, then cosine decay to 0.

    The warmup length is clamped to ``num_firstepoch_steps``. The name of the
    current phase is printed whenever it differs from the last printed one
    (tracked in the module-level ``last_print_label``).
    """
    global last_print_label

    warmup_end = min(num_warmup_steps, num_firstepoch_steps)

    if current_step < warmup_end:
        phase = 'Scheduler: Warmup'
        multiplier = float(current_step) / float(max(1, warmup_end))
    elif current_step < num_firstepoch_steps:
        phase = 'Scheduler: Hold'
        multiplier = 1.0
    else:
        phase = 'Scheduler: Annealing'
        cycles = 0.5
        fraction = float(current_step - num_firstepoch_steps) / float(max(1, num_training_steps - num_firstepoch_steps))
        multiplier = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * fraction)))

    if phase != last_print_label:
        print(phase)
    last_print_label = phase

    return multiplier
# halve lr each epoch
def _get_fp_cdrop_rate_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
    """LR multiplier: linear warmup, hold at 1.0 through the first epoch, then halve every epoch.

    The epoch number is derived as ``current_step // num_firstepoch_steps + 1``
    and the multiplier after the first epoch is ``1 / 2 ** (epoch - 1)``.
    The warmup length is clamped to ``num_firstepoch_steps``;
    ``num_training_steps`` is accepted for signature parity and is not used.
    """
    global last_print_label

    warmup_end = min(num_warmup_steps, num_firstepoch_steps)
    # Computed up front, exactly as before, so a zero epoch length fails here.
    epoch_number = (current_step // num_firstepoch_steps) + 1

    if current_step < warmup_end:
        phase = 'Scheduler: Warmup'
        multiplier = float(current_step) / float(max(1, warmup_end))
    elif current_step < num_firstepoch_steps:
        phase = 'Scheduler: Hold'
        multiplier = 1.0
    else:
        phase = 'Scheduler: Drop Rate'
        multiplier = 1.0 / float(2 ** (epoch_number - 1))

    if phase != last_print_label:
        print(phase)
    last_print_label = phase

    return multiplier
# epoch decay: 1/(1 + decay * epoch)
def custom_cosine_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
    """Build a ``LambdaLR`` driven by ``_get_fp_cosine_schedule_with_warmup_lr_lambda``.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Number of warmup steps.
        num_training_steps (`int`):
            Total number of training steps.
        num_firstepoch_steps (`int`):
            Number of steps in the first epoch (end of the hold phase).
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_kwargs = {
        'num_warmup_steps': num_warmup_steps,
        'num_training_steps': num_training_steps,
        'num_firstepoch_steps': num_firstepoch_steps,
    }
    multiplier_fn = partial(_get_fp_cosine_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, multiplier_fn, last_epoch)
def custom_half_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
    """Build a ``LambdaLR`` driven by ``_get_fp_half_schedule_with_warmup_lr_lambda``.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Number of warmup steps.
        num_training_steps (`int`):
            Total number of training steps.
        num_firstepoch_steps (`int`):
            Number of steps in the first epoch; forwarded to the lambda.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_kwargs = {
        'num_warmup_steps': num_warmup_steps,
        'num_training_steps': num_training_steps,
        'num_firstepoch_steps': num_firstepoch_steps,
    }
    multiplier_fn = partial(_get_fp_half_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, multiplier_fn, last_epoch)
def custom_raise_fall_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
    """Build a ``LambdaLR`` driven by ``_get_fp_cosine_raise_and_fall_lr_lambda``.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Number of warmup steps; forwarded to the lambda.
        num_training_steps (`int`):
            Total number of training steps.
        num_firstepoch_steps (`int`):
            Number of steps in the first epoch; forwarded to the lambda.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_kwargs = {
        'num_warmup_steps': num_warmup_steps,
        'num_training_steps': num_training_steps,
        'num_firstepoch_steps': num_firstepoch_steps,
    }
    multiplier_fn = partial(_get_fp_cosine_raise_and_fall_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, multiplier_fn, last_epoch)
def neftune_forward(self, input: torch.Tensor):
    """
    Replacement ``forward`` for an embedding layer that adds NEFTune noise while training.
    Adapted from https://github.com/neelsjain/NEFTune

    The layer (``self``) must carry a ``neftune_noise_alpha`` attribute. In
    training mode, uniform noise in ``[-alpha / sqrt(L * d), alpha / sqrt(L * d)]``
    is added, where ``L`` and ``d`` are sizes 1 and 2 of the embedding output.
    Outside training the plain embedding lookup is returned.

    Args:
        input (`torch.Tensor`):
            The input tensor to the embedding layer.
    """
    looked_up = torch.nn.functional.embedding(
        input,
        self.weight,
        self.padding_idx,
        self.max_norm,
        self.norm_type,
        self.scale_grad_by_freq,
        self.sparse,
    )

    if not self.training:
        return looked_up

    # Noise magnitude shrinks with the number of elements per sample.
    element_count = torch.tensor(looked_up.size(1) * looked_up.size(2))
    bound = self.neftune_noise_alpha / torch.sqrt(element_count)
    noise = torch.zeros_like(looked_up).uniform_(-bound, bound)
    return looked_up + noise
class FPNEFtuneTrainer(transformers.Trainer):
    """``transformers.Trainer`` that can add NEFTune noise to the input embeddings.

    When ``neftune_noise_alpha`` is greater than 0, the input-embedding layer's
    ``forward`` is replaced by ``neftune_forward`` before the base trainer is
    initialised, and the original ``forward`` is put back after ``train()``.
    """

    def __init__(self, neftune_noise_alpha: float = 0.0, model=None, *args, **kwargs):
        self.neftune_noise_alpha = neftune_noise_alpha
        if self.neftune_noise_alpha > 0.0:
            model = self._activate_neftune(model)
        super().__init__(model=model, *args, **kwargs)

    def _activate_neftune(self, model):
        r"""
        Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914

        Raises:
            TypeError: if ``model`` is neither a ``transformers.PreTrainedModel``
                nor a ``PeftModel`` (previously this surfaced as an
                ``UnboundLocalError`` on the ``embeddings`` variable).
        """
        print(f"Activating {RED}NEFtune{RESET} with scale: {self.neftune_noise_alpha}")
        if isinstance(model, transformers.PreTrainedModel):
            embeddings = model.get_input_embeddings()
        elif isinstance(model, PeftModel):
            embeddings = model.base_model.get_input_embeddings()
        else:
            raise TypeError(
                f"NEFTune requires a PreTrainedModel or PeftModel, got {type(model).__name__}"
            )

        embeddings.neftune_noise_alpha = self.neftune_noise_alpha
        old_forward = embeddings.forward

        # This hack seems to be needed to properly use a custom forward pass
        # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
        bound_method = neftune_forward.__get__(embeddings, embeddings.__class__)
        setattr(embeddings, "forward", bound_method)

        # Kept on the layer so train() can restore it afterwards.
        embeddings._trl_old_forward = old_forward
        return model

    def train(self, *args, **kwargs):
        """Run training, then restore the embedding layer's original forward pass.

        The restore happens in a ``finally`` block so the patched forward does
        not stay on the model when training raises.
        """
        try:
            output = super().train(*args, **kwargs)
        finally:
            if self.neftune_noise_alpha is not None:
                # Stays None for unsupported model types instead of being unbound.
                embeddings = None
                if isinstance(self.model, transformers.PreTrainedModel):
                    embeddings = self.model.get_input_embeddings()
                elif isinstance(self.model, PeftModel):
                    embeddings = self.model.base_model.get_input_embeddings()

                if embeddings is not None and hasattr(embeddings, "_trl_old_forward"):
                    embeddings.forward = embeddings._trl_old_forward
                    del embeddings._trl_old_forward
                    del embeddings.neftune_noise_alpha

        return output
class FPSchedulerTrainer(transformers.Trainer):
    """``transformers.Trainer`` with optional NEFTune noise and custom LR schedulers.

    ``create_scheduler`` maps three ``lr_scheduler_type`` values onto the custom
    schedulers defined in this file and defers to the base class for the rest:

    * ``'cosine'``               -> warmup, hold for the first epoch, cosine annealing
    * ``'constant'``             -> warmup, hold until half of training, cosine annealing
    * ``'constant_with_warmup'`` -> cosine raise then cosine fall (warmup ignored)
    """

    def __init__(self, neftune_noise_alpha: float = 0.0, model=None, *args, **kwargs):
        self.neftune_noise_alpha = neftune_noise_alpha
        if self.neftune_noise_alpha > 0.0:
            model = self._activate_neftune(model)
        super().__init__(model=model, *args, **kwargs)

    def _activate_neftune(self, model):
        r"""
        Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914

        Raises:
            TypeError: if ``model`` is neither a ``transformers.PreTrainedModel``
                nor a ``PeftModel`` (previously this surfaced as an
                ``UnboundLocalError`` on the ``embeddings`` variable).
        """
        print(f"Activating {RED}NEFtune{RESET} with scale: {self.neftune_noise_alpha}")
        if isinstance(model, transformers.PreTrainedModel):
            embeddings = model.get_input_embeddings()
        elif isinstance(model, PeftModel):
            embeddings = model.base_model.get_input_embeddings()
        else:
            raise TypeError(
                f"NEFTune requires a PreTrainedModel or PeftModel, got {type(model).__name__}"
            )

        embeddings.neftune_noise_alpha = self.neftune_noise_alpha
        old_forward = embeddings.forward

        # This hack seems to be needed to properly use a custom forward pass
        # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
        bound_method = neftune_forward.__get__(embeddings, embeddings.__class__)
        setattr(embeddings, "forward", bound_method)

        # Kept on the layer so train() can restore it afterwards.
        embeddings._trl_old_forward = old_forward
        return model

    def train(self, *args, **kwargs):
        """Run training, then restore the embedding layer's original forward pass.

        The restore happens in a ``finally`` block so the patched forward does
        not stay on the model when training raises.
        """
        try:
            output = super().train(*args, **kwargs)
        finally:
            if self.neftune_noise_alpha is not None:
                # Stays None for unsupported model types instead of being unbound.
                embeddings = None
                if isinstance(self.model, transformers.PreTrainedModel):
                    embeddings = self.model.get_input_embeddings()
                elif isinstance(self.model, PeftModel):
                    embeddings = self.model.base_model.get_input_embeddings()

                if embeddings is not None and hasattr(embeddings, "_trl_old_forward"):
                    embeddings.forward = embeddings._trl_old_forward
                    del embeddings._trl_old_forward
                    del embeddings.neftune_noise_alpha

        return output

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """Set up the LR scheduler.

        The optimizer of the trainer must have been set up either before this
        method is called or passed as an argument. The ``*_acc`` values
        (multiplied by ``gradient_accumulation_steps``) are used only in the
        printed summary; the schedulers receive the unmultiplied step counts.
        """
        num_train_epochs = self.args.num_train_epochs
        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
        num_firstepoch_steps = math.ceil(num_training_steps / num_train_epochs)
        num_warmup_acc = num_warmup_steps * self.args.gradient_accumulation_steps
        num_firstepoch_steps_acc = num_firstepoch_steps * self.args.gradient_accumulation_steps
        num_training_steps_acc = num_training_steps * self.args.gradient_accumulation_steps

        custom_scheduler_params.update({'dynamic_scheduler_stop': False})

        print (f"Warm-up steps aligned to Gradient accumulation ({self.args.gradient_accumulation_steps}) = {num_warmup_acc} actual warmup steps")

        if self.args.lr_scheduler_type == 'cosine':
            num_warmup_acc_min = min(num_warmup_acc, num_firstepoch_steps_acc)

            if num_warmup_acc > num_firstepoch_steps_acc:
                print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to 1 epoch, essentially going from warmup to annealing.\033[0;37;0m")
                print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")
            else:
                print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")

            self.lr_scheduler = custom_cosine_scheduler_with_warmup(
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_firstepoch_steps=num_firstepoch_steps,
            )
            self._created_lr_scheduler = True
            return self.lr_scheduler

        elif self.args.lr_scheduler_type == 'constant':
            half_step_acc = num_training_steps_acc // 2
            num_warmup_acc_min = min(num_warmup_acc, half_step_acc)

            if num_warmup_acc > half_step_acc:
                print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to half of all epochs, essentially going from warmup to annealing in the middle.\033[0;37;0m")
                print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{half_step_acc}, Annealing {half_step_acc}-{num_training_steps_acc}")
            else:
                print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{half_step_acc}, Annealing {half_step_acc}-{num_training_steps_acc}")

            self.lr_scheduler = custom_half_scheduler_with_warmup(
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_firstepoch_steps=num_firstepoch_steps,
            )
            self._created_lr_scheduler = True
            return self.lr_scheduler

        elif self.args.lr_scheduler_type == 'constant_with_warmup':
            half_step_acc = num_training_steps_acc // 2

            if num_warmup_steps > 0:
                print(f"Warmup doesn't apply to this scheduler [Raise-Fall]")

            print (f"Scheduler Raise: 0-{half_step_acc}, Fall {half_step_acc}-{num_training_steps_acc}")

            self.lr_scheduler = custom_raise_fall_scheduler_with_warmup(
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                num_firstepoch_steps=num_firstepoch_steps,
            )
            self._created_lr_scheduler = True
            return self.lr_scheduler

        else:
            return super().create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

View file

@ -0,0 +1,62 @@
import os
import json
def create_graph(lora_path, lora_name):
    """Render ``training_graph.json`` from ``lora_path`` into ``training_graph.png``.

    The JSON file is expected to be a list of records with ``'epoch'``,
    ``'learning_rate'`` and ``'loss'`` keys. Learning rate and loss are drawn
    against epoch on two y-axes. If matplotlib is not installed or the JSON
    file is missing, a message is printed and nothing is written.

    Args:
        lora_path: folder containing ``training_graph.json``; the PNG is saved there.
        lora_name: name used in the chart title.
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.ticker import ScalarFormatter

        peft_model_path = f'{lora_path}/training_graph.json'
        image_model_path = f'{lora_path}/training_graph.png'

        # Check if the JSON file exists
        if os.path.exists(peft_model_path):
            # Load data from JSON file
            with open(peft_model_path, 'r') as file:
                data = json.load(file)

            # Extract x, y1, and y2 values
            x = [item['epoch'] for item in data]
            y1 = [item['learning_rate'] for item in data]
            y2 = [item['loss'] for item in data]

            # Create the line chart
            fig, ax1 = plt.subplots(figsize=(10, 6))
            try:
                # Plot y1 (learning rate) on the first y-axis
                ax1.plot(x, y1, 'b-', label='Learning Rate')
                ax1.set_xlabel('Epoch')
                ax1.set_ylabel('Learning Rate', color='b')
                ax1.tick_params('y', colors='b')

                # Create a second y-axis and plot y2 (loss) on it
                ax2 = ax1.twinx()
                ax2.plot(x, y2, 'r-', label='Loss')
                ax2.set_ylabel('Loss', color='r')
                ax2.tick_params('y', colors='r')

                # Set the y-axis formatter to display numbers in scientific notation
                ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
                ax1.ticklabel_format(style='sci', axis='y', scilimits=(0,0))

                # Add grid
                ax1.grid(True)

                # Combine the legends for both plots
                lines, labels = ax1.get_legend_handles_labels()
                lines2, labels2 = ax2.get_legend_handles_labels()
                ax2.legend(lines + lines2, labels + labels2, loc='best')

                # Set the title
                plt.title(f'{lora_name} LR and Loss vs Epoch')

                # Save the chart as an image
                plt.savefig(image_model_path)
            finally:
                # pyplot keeps every figure alive until it is closed; without
                # this each call would leak one figure in a long-running process.
                plt.close(fig)

            print(f"Graph saved in {image_model_path}")
        else:
            print(f"File 'training_graph.json' does not exist in the {lora_path}")

    except ImportError:
        print("matplotlib is not installed. Please install matplotlib to create PNG graphs")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,368 @@
import os
from modules import shared, utils
from pathlib import Path
import requests
import tqdm
import json
'''
def get_gpu_memory_usage(rank):
return {
'total': round(torch.cuda.get_device_properties(rank).total_memory / (1024**3), 2),
'max': round(torch.cuda.max_memory_allocated(rank) / (1024**3), 2),
'reserved': round(torch.cuda.memory_reserved(rank) / (1024**3), 2),
'allocated': round(torch.cuda.memory_allocated(rank) / (1024**3), 2)
}
'''
def list_subfoldersByTime(directory):
    """List the immediate subfolders of ``directory``, most recently modified first.

    Args:
        directory: path of the folder to scan (``str`` or path-like).

    Returns:
        A list starting with the literal ``'None'`` followed by the subfolder
        names (not full paths), ordered by modification time, newest first.
    """
    directory = os.fspath(directory)
    if not directory.endswith('/'):
        directory += '/'

    # Keep directories only before sorting, so files and broken symlinks never
    # reach os.path.getmtime.
    folder_paths = [
        candidate
        for candidate in (os.path.join(directory, name) for name in os.listdir(directory))
        if os.path.isdir(candidate)
    ]
    folder_paths.sort(key=os.path.getmtime, reverse=True)

    # basename() yields the folder name regardless of which separators the
    # directory argument uses; the previous string replacement returned full
    # paths whenever ``directory`` itself contained a backslash.
    return ['None'] + [os.path.basename(folder) for folder in folder_paths]
def get_available_loras_local(_sortedByTime):
    """Return the available LoRA names.

    When ``_sortedByTime`` is truthy, the folders under ``shared.args.lora_dir``
    are listed via ``list_subfoldersByTime``; otherwise the result of
    ``utils.get_available_loras()`` is returned.
    """
    lora_root = shared.args.lora_dir
    if _sortedByTime:
        return list_subfoldersByTime(lora_root)
    return utils.get_available_loras()
# FPHAM SPLIT BY SENTENCE BLOCK ===============
def split_sentences(text: str, cutoff_len: int):
    """Split ``text`` into sentences and measure each one with ``shared.tokenizer``.

    A sentence ends when the accumulated text ends with one of the delimiters,
    unless it ends with a known abbreviation or the character before the
    current one is upper-case and the third-from-last character is not a space.
    Sentences longer than ``cutoff_len - 1`` tokens are truncated to that
    length and re-decoded.

    Returns:
        A list of ``{'text': str, 'size': int}`` dicts, ``size`` being the
        token count of the (possibly truncated) sentence.
    """
    terminators = ('. ', '? ', '! ', '... ', '.\n', '?\n', '!\n', '...\n', '</s>', '<//>')
    known_abbreviations = ('Mr. ', 'Mrs. ', 'Dr. ', 'Ms. ', 'St. ', 'Prof. ', 'Jr. ', 'Ltd. ', 'Capt. ', 'Col. ', 'Gen. ', 'Ave. ', 'Blvd. ', 'Co. ', 'Corp. ', 'Dept. ', 'Est. ', 'Gov. ', 'Inc. ', 'Ph.D. ', 'Univ. ')
    token_limit = cutoff_len - 1

    collected = []
    trimmed = 0
    buffer = ''
    previous = ''

    for character in text:
        buffer += character

        at_delimiter = buffer.endswith(terminators)
        after_initial = previous.isupper() and len(buffer) >= 3 and buffer[-3] != ' '
        at_abbreviation = buffer.endswith(known_abbreviations)

        if at_delimiter and not after_initial and not at_abbreviation:
            token_ids = shared.tokenizer.encode(buffer)
            if len(token_ids) > token_limit:
                token_ids = token_ids[:token_limit]
                buffer = shared.tokenizer.decode(token_ids, skip_special_tokens=True)
                trimmed += 1
            collected.append({'text': buffer, 'size': len(token_ids)})
            buffer = ''

        previous = character

    # Whatever is left after the last delimiter becomes the final sentence.
    if buffer:
        token_ids = shared.tokenizer.encode(buffer)
        if len(token_ids) > token_limit:
            token_ids = token_ids[:token_limit]
            buffer = shared.tokenizer.decode(token_ids, skip_special_tokens=True)
            trimmed += 1
        collected.append({'text': buffer, 'size': len(token_ids)})

    if trimmed > 0:
        print(f"Trimmed sentences beyond Cutoff Length: {trimmed}")

    return collected
# The goal of following code is to create blocks of text + overlapping blocks while:
# respects sentence boundaries
# always uses all the text
# hard cut defined by hard_cut_string or </s> will always end at the end of data block
# no overlapping blocks will be created across hard cut or across </s> token
# Slice raw text into training blocks of fewer than (cutoff_len - 1) tokens,
# built from whole sentences returned by split_sentences().
#   text            - raw input text
#   overlap         - when true, extra blocks starting at the recorded
#                     "edge" sentence indexes are appended
#   min_chars_cut   - blocks whose stripped length is not greater than this
#                     are dropped
#   eos_to_hc       - when true the hard-cut placeholder becomes '</s>',
#                     otherwise it is removed
#   cutoff_len      - token budget per block
#   hard_cut_string - marker in the text that forces a block boundary
#                     (a literal backslash-n in it is turned into a newline)
#   debug_slicer    - when true the blocks are dumped to
#                     user_data/logs/sentencelist.json
# Returns the list of block strings.
# NOTE(review): indentation is not visible in this view, so the exact nesting
# of the statements below should be confirmed against the real file.
def precise_cut(text: str, overlap: bool, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer:bool):
EOSX_str = '<//>' #hardcut placeholder
EOS_str = '</s>'
print("Precise raw text slicer: ON")
# Replace every hard-cut marker with the placeholder; '<//>' is also one of
# the delimiters split_sentences() ends a sentence on.
cut_string = hard_cut_string.replace('\\n', '\n')
text = text.replace(cut_string, EOSX_str)
sentences = split_sentences(text, cutoff_len)
print(f"Sentences: {len(sentences)}")
sentencelist = []
currentSentence = ''
totalLength = 0
max_cut = cutoff_len-1
half_cut = cutoff_len//2
halfcut_length = 0
# Sentence indexes recorded while walking the text; used as start points for
# the overlapping blocks further down.
edgeindex = []
half_index = 0
for index, item in enumerate(sentences):
# Track how far into the current block we are, up to half_cut tokens.
if halfcut_length+ item['size'] < half_cut:
halfcut_length += item['size']
half_index = index
else:
edgeindex.append(half_index)
# Large negative value: keeps the branch above taken until halfcut_length
# is assigned again when a new block starts.
halfcut_length = -2 * max_cut
# Grow the current block while it stays under max_cut tokens and does not
# already end with the hard-cut placeholder.
if totalLength + item['size'] < max_cut and not currentSentence.endswith(EOSX_str):
currentSentence += item['text']
totalLength += item['size']
else:
if len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
# Start the next block with the sentence that did not fit.
currentSentence = item['text']
totalLength = item['size']
halfcut_length = item['size']
# Flush the last block.
if len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
unique_blocks = len(sentencelist)
print(f"Text Blocks: {unique_blocks}")
#overlap strategies:
# don't overlap across HARD CUT (EOSX)
if overlap:
for edge_idx in edgeindex:
currentSentence = ''
totalLength = 0
# Build one block starting at the edge sentence.
for item in sentences[edge_idx:]:
if totalLength + item['size'] < max_cut:
currentSentence += item['text']
totalLength += item['size']
else:
#if by chance EOSX is at the end then it's acceptable
if currentSentence.endswith(EOSX_str) and len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
# otherwise don't cross hard cut
elif EOSX_str not in currentSentence and len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
currentSentence = ''
totalLength = 0
break
print(f"+ Overlapping blocks: {len(sentencelist)-unique_blocks}")
num_EOS = 0
# Turn the placeholder into '</s>' or drop it, and count '</s>' occurrences.
for i in range(len(sentencelist)):
if eos_to_hc:
sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
else:
sentencelist[i] = sentencelist[i].replace(EOSX_str, '')
#someone may have had stop strings in the raw text...
sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
num_EOS += sentencelist[i].count(EOS_str)
if num_EOS > 0:
print(f"+ EOS count: {num_EOS}")
#final check for useless lines
sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
sentencelist = [item for item in sentencelist if item.strip() != ""]
if debug_slicer:
# Write the log file
# NOTE(review): mkdir is called without parents=True, so it raises if the
# 'user_data' folder does not already exist.
Path('user_data/logs').mkdir(exist_ok=True)
sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
output_file = "user_data/logs/sentencelist.json"
with open(output_file, 'w') as f:
json.dump(sentencelist_dict, f,indent=2)
print("Saved sentencelist.json in user_data/logs folder")
return sentencelist
# Slice raw text into blocks of at most (cutoff_len - 1) tokens by starting a
# candidate block at successive sentences returned by split_sentences().
#   text            - raw input text
#   min_chars_cut   - blocks whose stripped length is not greater than this
#                     are dropped
#   eos_to_hc       - when true the hard-cut placeholder becomes '</s>',
#                     otherwise it is removed
#   cutoff_len      - token budget per block
#   hard_cut_string - marker in the text that forces a block boundary
#                     (a literal backslash-n in it is turned into a newline)
#   debug_slicer    - when true the blocks are dumped to
#                     user_data/logs/sentencelist.json
# Returns the list of block strings.
# NOTE(review): indentation is not visible in this view, so the exact nesting
# of the statements below should be confirmed against the real file.
def sliding_block_cut(text: str, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer:bool):
EOSX_str = '<//>' #hardcut placeholder
EOS_str = '</s>'
print("Mega Block Overlap: ON")
# Replace every hard-cut marker with the placeholder; '<//>' is also one of
# the delimiters split_sentences() ends a sentence on.
cut_string = hard_cut_string.replace('\\n', '\n')
text = text.replace(cut_string, EOSX_str)
sentences = split_sentences(text, cutoff_len)
print(f"Sentences: {len(sentences)}")
sentencelist = []
max_cut = cutoff_len-1
#print(f"max_cut: {max_cut}")
# Start indexes below advancing_to are skipped; it is set after a block that
# ended on a hard cut.
advancing_to = 0
# Last sentence of the most recently stored block, compared against the next
# candidate's last sentence before that candidate is stored.
prev_block_lastsentence = ""
for i in range(len(sentences)):
totalLength = 0
currentSentence = ''
lastsentence = ""
if i >= advancing_to:
for k in range(i, len(sentences)):
current_length = sentences[k]['size']
# Grow the block while it fits and has not reached a hard cut.
if totalLength + current_length <= max_cut and not currentSentence.endswith(EOSX_str):
currentSentence += sentences[k]['text']
totalLength += current_length
lastsentence = sentences[k]['text']
else:
if len(currentSentence.strip()) > min_chars_cut:
if prev_block_lastsentence!=lastsentence:
sentencelist.append(currentSentence.strip())
prev_block_lastsentence = lastsentence
advancing_to = 0
# After a hard cut, continue from sentence k.
if currentSentence.endswith(EOSX_str):
advancing_to = k
currentSentence = ""
totalLength = 0
break
# A block that ran to the end of the text without overflowing is stored here.
if currentSentence != "":
if len(currentSentence.strip()) > min_chars_cut:
sentencelist.append(currentSentence.strip())
unique_blocks = len(sentencelist)
print(f"Text Blocks: {unique_blocks}")
num_EOS = 0
# Turn the placeholder into '</s>' or drop it, and count '</s>' occurrences.
for i in range(len(sentencelist)):
if eos_to_hc:
sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
else:
sentencelist[i] = sentencelist[i].replace(EOSX_str, '')
#someone may have had stop strings in the raw text...
sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
num_EOS += sentencelist[i].count(EOS_str)
if num_EOS > 0:
print(f"+ EOS count: {num_EOS}")
#final check for useless lines
sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
sentencelist = [item for item in sentencelist if item.strip() != ""]
if debug_slicer:
# Write the log file
# NOTE(review): mkdir is called without parents=True, so it raises if the
# 'user_data' folder does not already exist.
Path('user_data/logs').mkdir(exist_ok=True)
sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
output_file = "user_data/logs/sentencelist.json"
with open(output_file, 'w') as f:
json.dump(sentencelist_dict, f,indent=2)
print("Saved sentencelist.json in user_data/logs folder")
return sentencelist
# Example usage:
# download_file_from_url('https://example.com/path/to/your/file.ext', '/output/directory')
def download_file_from_url(url, overwrite, output_dir_in, valid_extensions = {'.txt', '.json'}):
    """Download ``url`` into ``output_dir_in``, yielding status strings as it goes.

    This is a generator: each yielded string is a progress or result message.
    Errors are reported as a yielded message rather than raised.

    Args:
        url: address of the file to download.
        overwrite: when false, an existing local file aborts the download.
        output_dir_in: destination folder (converted with ``str()``).
        valid_extensions: lower-case extensions that are accepted. The default
            set is shared between calls and is only read, never modified.
    """
    from urllib.parse import urlsplit

    # Defined before the try block so the finally clause never sees an unbound name.
    session = None
    try:
        session = requests.Session()
        headers = {}
        mode = 'wb'

        # Take the name from the URL *path* only, so a query string or fragment
        # does not end up in the file name or break the extension check, and
        # treat backslashes as separators so the name cannot carry extra path
        # components on Windows.
        url_path = urlsplit(url).path
        filename = url_path.replace('\\', '/').split('/')[-1]
        output_dir = str(output_dir_in)

        # Construct the full path to the output file
        local_filename = os.path.join(output_dir, filename)

        # Check if the local file already exists
        overw = ''
        if os.path.exists(local_filename):
            if not overwrite:
                yield f"File '{local_filename}' already exists. Aborting."
                return
            else:
                overw = ' [Overwrite existing]'

        filename_lower = filename.lower()
        file_extension = os.path.splitext(filename_lower)[-1]
        if file_extension not in valid_extensions:
            yield f"Invalid file extension: {file_extension}. Only {valid_extensions} files are supported."
            return

        # Send an HTTP GET request to the URL with a timeout
        with session.get(url, stream=True, headers=headers, timeout=10) as r:
            r.raise_for_status()
            # The content-length header can be wildly inaccurate, so progress
            # is reported as a running byte count instead.
            block_size = 1024 * 4
            with open(local_filename, mode) as f:
                count = 0
                for data in r.iter_content(block_size):
                    f.write(data)
                    count += len(data)
                    yield f"Downloaded: {count} " + overw

        # Verify file size if possible
        if os.path.exists(local_filename):
            downloaded_size = os.path.getsize(local_filename)
            if downloaded_size > 0:
                yield f"File '{filename}' downloaded to '{output_dir}' ({downloaded_size} bytes)."
                print("File Downloaded")
            else:
                print("Downloaded file is zero")
                yield f"Failed. Downloaded file size is zero)."
        else:
            print(f"Error: {local_filename} failed to download.")
            yield f"Error: {local_filename} failed to download"

    except Exception as e:
        print(f"An error occurred: {e}")
        yield f"An error occurred: {e}"

    finally:
        # Close the session to release resources
        if session is not None:
            session.close()

View file

@ -2,7 +2,6 @@ from pathlib import Path
import gradio as gr import gradio as gr
import modules.shared as shared
from modules.html_generator import get_image_cache from modules.html_generator import get_image_cache
from modules.shared import gradio from modules.shared import gradio
@ -73,13 +72,13 @@ def generate_html():
global cards global cards
cards = [] cards = []
# Iterate through files in image folder # Iterate through files in image folder
for file in sorted((shared.user_data_dir / "characters").glob("*")): for file in sorted(Path("user_data/characters").glob("*")):
if file.suffix in [".json", ".yml", ".yaml"]: if file.suffix in [".json", ".yml", ".yaml"]:
character = file.stem character = file.stem
container_html = '<div class="character-container">' container_html = '<div class="character-container">'
image_html = "<div class='placeholder'></div>" image_html = "<div class='placeholder'></div>"
for path in [shared.user_data_dir / "characters" / f"{character}.{extension}" for extension in ['png', 'jpg', 'jpeg']]: for path in [Path(f"user_data/characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
if path.exists(): if path.exists():
image_html = f'<img src="file/{get_image_cache(path)}">' image_html = f'<img src="file/{get_image_cache(path)}">'
break break

View file

@ -1,20 +1,15 @@
import copy import copy
import functools
import json import json
import time import time
from collections import deque from collections import deque
from pathlib import Path
import tiktoken import tiktoken
import yaml
from pydantic import ValidationError from pydantic import ValidationError
from extensions.openai.errors import InvalidRequestError from extensions.openai.errors import InvalidRequestError
from extensions.openai.typing import ToolDefinition from extensions.openai.typing import ToolDefinition
from extensions.openai.utils import debug_msg from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
from modules.tool_parsing import get_tool_call_id, parse_tool_call, detect_tool_call_format
from modules import shared from modules import shared
from modules.reasoning import extract_reasoning
from modules.chat import ( from modules.chat import (
generate_chat_prompt, generate_chat_prompt,
generate_chat_reply, generate_chat_reply,
@ -27,126 +22,17 @@ from modules.presets import load_preset_memoized
from modules.text_generation import decode, encode, generate_reply from modules.text_generation import decode, encode, generate_reply
@functools.cache def convert_logprobs_to_tiktoken(model, logprobs):
def load_chat_template_file(filepath): # more problems than it's worth.
"""Load a chat template from a file path (.jinja, .jinja2, or .yaml/.yml).""" # try:
filepath = Path(filepath) # encoder = tiktoken.encoding_for_model(model)
ext = filepath.suffix.lower() # # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
text = filepath.read_text(encoding='utf-8') # return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
if ext in ['.yaml', '.yml']: # except KeyError:
data = yaml.safe_load(text) # # assume native tokens if we can't find the tokenizer
return data.get('instruction_template', '') # return logprobs
return text
return logprobs
def _get_raw_logprob_entries(offset=0):
"""Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.
Returns (new_entries, new_offset).
"""
if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
return [], offset
all_entries = shared.model.last_completion_probabilities
new_entries = all_entries[offset:]
return new_entries, len(all_entries)
def _dict_to_logprob_entries(token_dict):
"""Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format."""
if not token_dict:
return []
return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}]
def _parse_entry_top(entry):
"""Extract the top logprobs list from a raw entry, handling both key names."""
return entry.get('top_logprobs', entry.get('top_probs', []))
def format_chat_logprobs(entries):
"""Format logprob entries into OpenAI chat completions logprobs format.
Output: {"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}]}
"""
if not entries:
return None
content = []
for entry in entries:
top = _parse_entry_top(entry)
if not top:
continue
chosen = top[0]
token_str = chosen.get('token', '')
token_logprob = chosen.get('logprob', chosen.get('prob', 0))
top_list = []
for item in top:
t = item.get('token', '')
lp = item.get('logprob', item.get('prob', 0))
top_list.append({
"token": t,
"logprob": lp,
"bytes": list(t.encode('utf-8')) if t else None
})
content.append({
"token": token_str,
"logprob": token_logprob,
"bytes": list(token_str.encode('utf-8')) if token_str else None,
"top_logprobs": top_list
})
return {"content": content, "refusal": None} if content else None
def format_completion_logprobs(entries):
    """Format raw logprob entries into the OpenAI (legacy) completions logprobs object.

    Args:
        entries: Iterable of raw entry dicts. Each entry's alternatives are
            read through `_parse_entry_top`; entries without alternatives are
            skipped.

    Returns:
        `{"tokens", "token_logprobs", "top_logprobs": [{token: logprob}],
        "text_offset"}` with one element per kept entry, or `None` when
        `entries` is empty or no entry yielded any alternatives.
    """
    if not entries:
        return None

    tokens = []
    token_logprobs = []
    top_logprobs = []
    text_offset = []
    offset = 0

    for entry in entries:
        top = _parse_entry_top(entry)
        if not top:
            continue

        # Report the token that was actually generated. top[0] is only the
        # first listed alternative; using it when a different token was
        # sampled also skews text_offset, which accumulates token lengths.
        # NOTE(review): assumes backend entries expose the sampled token as
        # entry-level 'token' plus 'logprob'/'prob' (llama.cpp server format)
        # — confirm for ExLlamav3. Entries without those keys fall back to
        # top[0].
        has_own_token = 'token' in entry and ('logprob' in entry or 'prob' in entry)
        chosen = entry if has_own_token else top[0]
        token_str = chosen.get('token', '')
        token_logprob = chosen.get('logprob', chosen.get('prob', 0))

        tokens.append(token_str)
        token_logprobs.append(token_logprob)
        # Offset of this token = total length of the tokens emitted before it.
        text_offset.append(offset)
        offset += len(token_str)

        top_logprobs.append({
            item.get('token', ''): item.get('logprob', item.get('prob', 0))
            for item in top
        })

    if not tokens:
        return None

    return {
        "tokens": tokens,
        "token_logprobs": token_logprobs,
        "top_logprobs": top_logprobs,
        "text_offset": text_offset
    }
def process_parameters(body, is_legacy=False): def process_parameters(body, is_legacy=False):
@ -171,16 +57,7 @@ def process_parameters(body, is_legacy=False):
elif isinstance(body['stop'], list): elif isinstance(body['stop'], list):
generate_params['custom_stopping_strings'] = body['stop'] generate_params['custom_stopping_strings'] = body['stop']
# Resolve logprobs: for chat completions, logprobs is a bool and the count if shared.args.loader != 'llama.cpp':
# comes from top_logprobs. Normalize to an int for all backends.
logprobs = body.get('logprobs', None)
top_logprobs = body.get('top_logprobs', None)
if logprobs is True:
logprobs = max(top_logprobs, 1) if top_logprobs is not None else 5
generate_params['logprobs'] = logprobs
# For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
from transformers import LogitsProcessorList from transformers import LogitsProcessorList
from modules.transformers_loader import ( from modules.transformers_loader import (
@ -193,9 +70,13 @@ def process_parameters(body, is_legacy=False):
if logit_bias: # {str: float, ...} if logit_bias: # {str: float, ...}
logits_processor = [LogitsBiasProcessor(logit_bias)] logits_processor = [LogitsBiasProcessor(logit_bias)]
if logprobs is not None and logprobs > 0: logprobs = None # coming to chat eventually
if 'logprobs' in body:
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
generate_params['logprob_proc'] = LogprobProcessor(logprobs) generate_params['logprob_proc'] = LogprobProcessor(logprobs)
logits_processor.extend([generate_params['logprob_proc']]) logits_processor.extend([generate_params['logprob_proc']])
else:
logprobs = None
if logits_processor: # requires logits_processor support if logits_processor: # requires logits_processor support
generate_params['logits_processor'] = LogitsProcessorList(logits_processor) generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
@ -241,58 +122,38 @@ def convert_history(history):
user_input = "" user_input = ""
user_input_last = True user_input_last = True
system_message = "" system_message = ""
seen_non_system = False
for entry in history: for entry in history:
content = entry["content"] content = entry["content"]
role = entry["role"] role = entry["role"]
if role == "user": if role == "user":
seen_non_system = True
# Extract text content (images handled by model-specific code) # Extract text content (images handled by model-specific code)
content = process_multimodal_content(content) content = process_multimodal_content(content)
user_input = content user_input = content
user_input_last = True user_input_last = True
if current_message: if current_message:
chat_dialogue.append([current_message, '', '', {}]) chat_dialogue.append([current_message, '', ''])
current_message = "" current_message = ""
current_message = content current_message = content
elif role == "assistant": elif role == "assistant":
seen_non_system = True if "tool_calls" in entry and isinstance(entry["tool_calls"], list) and len(entry["tool_calls"]) > 0 and content.strip() == "":
meta = {} continue # skip tool calls
tool_calls = entry.get("tool_calls")
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
meta["tool_calls"] = tool_calls
if content.strip() == "":
content = "" # keep empty content, don't skip
current_reply = content current_reply = content
user_input_last = False user_input_last = False
if current_message: if current_message:
chat_dialogue.append([current_message, current_reply, '', meta]) chat_dialogue.append([current_message, current_reply, ''])
current_message = "" current_message = ""
current_reply = "" current_reply = ""
else: else:
chat_dialogue.append(['', current_reply, '', meta]) chat_dialogue.append(['', current_reply, ''])
elif role == "tool": elif role == "tool":
seen_non_system = True
user_input_last = False user_input_last = False
meta = {} chat_dialogue.append(['', '', content])
if "tool_call_id" in entry: elif role == "system":
meta["tool_call_id"] = entry["tool_call_id"]
chat_dialogue.append(['', '', content, meta])
elif role in ("system", "developer"):
if not seen_non_system:
# Leading system messages go to custom_system_message (placed at top)
system_message += f"\n{content}" if system_message else content system_message += f"\n{content}" if system_message else content
else:
# Mid-conversation system messages: preserve position in history
if current_message:
chat_dialogue.append([current_message, '', '', {}])
current_message = ""
chat_dialogue.append([content, '', '', {"role": "system"}])
if not user_input_last: if not user_input_last:
user_input = "" user_input = ""
@ -304,7 +165,7 @@ def convert_history(history):
} }
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False, stop_event=None) -> dict: def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
if body.get('functions', []): if body.get('functions', []):
raise InvalidRequestError(message="functions is not supported.", param='functions') raise InvalidRequestError(message="functions is not supported.", param='functions')
@ -318,10 +179,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0: if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
tool_choice = body.get('tool_choice', None)
if tool_choice == "none":
tools = None # Disable tool detection entirely
messages = body['messages'] messages = body['messages']
for m in messages: for m in messages:
if 'role' not in m: if 'role' not in m:
@ -332,10 +189,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
# Handle multimodal content validation # Handle multimodal content validation
content = m.get('content') content = m.get('content')
if content is None: if content is None:
# OpenAI allows content: null on assistant messages when tool_calls is present
if m['role'] == 'assistant' and m.get('tool_calls'):
m['content'] = ''
else:
raise InvalidRequestError(message="messages: missing content", param='messages') raise InvalidRequestError(message="messages: missing content", param='messages')
# Validate multimodal content structure # Validate multimodal content structure
@ -358,8 +211,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
# generation parameters # generation parameters
generate_params = process_parameters(body, is_legacy=is_legacy) generate_params = process_parameters(body, is_legacy=is_legacy)
if stop_event is not None:
generate_params['stop_event'] = stop_event
continue_ = body['continue_'] continue_ = body['continue_']
# Instruction template # Instruction template
@ -369,8 +220,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
instruction_template = body['instruction_template'] instruction_template = body['instruction_template']
instruction_template = "Alpaca" if instruction_template == "None" else instruction_template instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
instruction_template_str = load_instruction_template_memoized(instruction_template) instruction_template_str = load_instruction_template_memoized(instruction_template)
elif shared.args.chat_template_file:
instruction_template_str = load_chat_template_file(shared.args.chat_template_file)
else: else:
instruction_template_str = shared.settings['instruction_template_str'] instruction_template_str = shared.settings['instruction_template_str']
@ -413,189 +262,106 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
requested_model = generate_params.pop('model') requested_model = generate_params.pop('model')
logprob_proc = generate_params.pop('logprob_proc', None) logprob_proc = generate_params.pop('logprob_proc', None)
if logprob_proc:
logprob_proc.token_alternatives_history.clear()
chat_logprobs_offset = [0] # mutable for closure access in streaming
def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False, reasoning_content=None): def chat_streaming_chunk(content, chunk_tool_calls=None):
# begin streaming # begin streaming
delta = {}
if include_role:
delta['role'] = 'assistant'
delta['refusal'] = None
if content is not None:
delta['content'] = content
if reasoning_content is not None:
delta['reasoning_content'] = reasoning_content
if chunk_tool_calls:
delta['tool_calls'] = chunk_tool_calls
chunk = { chunk = {
"id": cmpl_id, "id": cmpl_id,
"object": object_type, "object": object_type,
"created": created_time, "created": created_time,
"model": shared.model_name, "model": shared.model_name,
"system_fingerprint": None,
resp_list: [{ resp_list: [{
"index": 0, "index": 0,
"finish_reason": None, "finish_reason": None,
"delta": delta, "delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls},
"logprobs": None,
}], }],
} }
if logprob_proc: if logprob_proc: # not official for chat yet
entries = _dict_to_logprob_entries(logprob_proc.token_alternatives) top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
formatted = format_chat_logprobs(entries) chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
if formatted: # else:
chunk[resp_list][0]["logprobs"] = formatted # chunk[resp_list][0]["logprobs"] = None
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0])
if entries:
formatted = format_chat_logprobs(entries)
if formatted:
chunk[resp_list][0]["logprobs"] = formatted
return chunk return chunk
# Check if usage should be included in streaming chunks per OpenAI spec
stream_options = body.get('stream_options')
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
# generate reply ####################################### # generate reply #######################################
if prompt_only:
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_) prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
if prompt_only:
yield {'prompt': prompt} yield {'prompt': prompt}
return return
if stream: if stream:
chunk = chat_streaming_chunk('', include_role=True) yield chat_streaming_chunk('')
if include_usage:
chunk['usage'] = None
yield chunk
generator = generate_chat_reply( generator = generate_chat_reply(
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False) user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
answer = '' answer = ''
seen_content = '' seen_content = ''
seen_reasoning = ''
tool_calls = [] tool_calls = []
end_last_tool_call = 0 end_last_tool_call = 0
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
_tool_parsers = None
# Filter supported_tools when tool_choice specifies a particular function
if supported_tools and isinstance(tool_choice, dict):
specified_func = tool_choice.get("function", {}).get("name")
if specified_func and specified_func in supported_tools:
supported_tools = [specified_func]
if supported_tools is not None:
_template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '')
_tool_parsers, _, _ = detect_tool_call_format(_template_str)
for a in generator: for a in generator:
answer = a['internal'][-1][1] answer = a['internal'][-1][1]
if supported_tools is not None: if supported_tools is not None:
tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else [] tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
if len(tool_call) > 0: if len(tool_call) > 0:
for tc in tool_call: for tc in tool_call:
tc["id"] = get_tool_call_id() tc["id"] = getToolCallId()
if stream: tc["index"] = str(len(tool_calls))
tc["index"] = len(tool_calls)
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"]) tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
tool_calls.append(tc) tool_calls.append(tc)
end_last_tool_call = len(answer) end_last_tool_call = len(answer)
# Stop generation before streaming content if tool_calls were detected,
# so that raw tool markup is not sent as content deltas.
if len(tool_calls) > 0:
break
if stream: if stream:
# Strip reasoning/thinking blocks so only final content is streamed. len_seen = len(seen_content)
# Reasoning is emitted separately as reasoning_content deltas. new_content = answer[len_seen:]
reasoning, content = extract_reasoning(answer)
if reasoning is not None:
new_reasoning = reasoning[len(seen_reasoning):]
new_content = content[len(seen_content):]
else:
new_reasoning = None
new_content = answer[len(seen_content):]
if (not new_content and not new_reasoning) or chr(0xfffd) in (new_content or '') + (new_reasoning or ''): if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue continue
chunk = chat_streaming_chunk( chunk = chat_streaming_chunk(new_content)
content=new_content if new_content else None,
reasoning_content=new_reasoning if new_reasoning else None,
)
if include_usage:
chunk['usage'] = None
if reasoning is not None:
seen_reasoning = reasoning
seen_content = content
else:
seen_content = answer seen_content = answer
yield chunk yield chunk
token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0 # stop generation if tool_calls were generated previously
if len(tool_calls) > 0:
break
token_count = len(encode(prompt)[0])
completion_token_count = len(encode(answer)[0]) completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if len(tool_calls) > 0: if len(tool_calls) > 0:
stop_reason = "tool_calls" stop_reason = "tool_calls"
elif token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']: if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
stop_reason = "length" stop_reason = "length"
else:
stop_reason = "stop"
if stream: if stream:
chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls) chunk = chat_streaming_chunk('', tool_calls)
chunk[resp_list][0]['finish_reason'] = stop_reason chunk[resp_list][0]['finish_reason'] = stop_reason
usage = { chunk['usage'] = {
"prompt_tokens": token_count, "prompt_tokens": token_count,
"completion_tokens": completion_token_count, "completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count "total_tokens": token_count + completion_token_count
} }
if include_usage:
chunk['usage'] = None
yield chunk
# Separate usage-only chunk with choices: [] per OpenAI spec
yield {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
"system_fingerprint": None,
resp_list: [],
"usage": usage
}
else:
yield chunk yield chunk
else: else:
reasoning, content = extract_reasoning(answer)
message = {
"role": "assistant",
"refusal": None,
"content": None if tool_calls else content,
**({"reasoning_content": reasoning} if reasoning else {}),
**({"tool_calls": tool_calls} if tool_calls else {}),
}
resp = { resp = {
"id": cmpl_id, "id": cmpl_id,
"object": object_type, "object": object_type,
"created": created_time, "created": created_time,
"model": shared.model_name, "model": shared.model_name,
"system_fingerprint": None,
resp_list: [{ resp_list: [{
"index": 0, "index": 0,
"finish_reason": stop_reason, "finish_reason": stop_reason,
"message": message, "message": {"role": "assistant", "content": answer},
"logprobs": None, "tool_calls": tool_calls
}], }],
"usage": { "usage": {
"prompt_tokens": token_count, "prompt_tokens": token_count,
@ -603,27 +369,19 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
"total_tokens": token_count + completion_token_count "total_tokens": token_count + completion_token_count
} }
} }
if logprob_proc: if logprob_proc: # not official for chat yet
all_entries = [] top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
for alt in logprob_proc.token_alternatives_history: resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
all_entries.extend(_dict_to_logprob_entries(alt)) # else:
formatted = format_chat_logprobs(all_entries) # resp[resp_list][0]["logprobs"] = None
if formatted:
resp[resp_list][0]["logprobs"] = formatted
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
raw = getattr(shared.model, 'last_completion_probabilities', None)
if raw:
formatted = format_chat_logprobs(raw)
if formatted:
resp[resp_list][0]["logprobs"] = formatted
yield resp yield resp
def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None): def completions_common(body: dict, is_legacy: bool = False, stream=False):
object_type = 'text_completion' object_type = 'text_completion.chunk' if stream else 'text_completion'
created_time = int(time.time()) created_time = int(time.time())
cmpl_id = "cmpl-%d" % (int(time.time() * 1000000000)) cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices' resp_list = 'data' if is_legacy else 'choices'
prompt_str = 'context' if is_legacy else 'prompt' prompt_str = 'context' if is_legacy else 'prompt'
@ -653,12 +411,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
generate_params = process_parameters(body, is_legacy=is_legacy) generate_params = process_parameters(body, is_legacy=is_legacy)
max_tokens = generate_params['max_new_tokens'] max_tokens = generate_params['max_new_tokens']
generate_params['stream'] = stream generate_params['stream'] = stream
if stop_event is not None:
generate_params['stop_event'] = stop_event
requested_model = generate_params.pop('model') requested_model = generate_params.pop('model')
logprob_proc = generate_params.pop('logprob_proc', None) logprob_proc = generate_params.pop('logprob_proc', None)
if logprob_proc:
logprob_proc.token_alternatives_history.clear()
suffix = body['suffix'] if body['suffix'] else '' suffix = body['suffix'] if body['suffix'] else ''
echo = body['echo'] echo = body['echo']
@ -670,8 +424,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
logger.info(f"Found {len(raw_images)} image(s) in request.") logger.info(f"Found {len(raw_images)} image(s) in request.")
generate_params['raw_images'] = raw_images generate_params['raw_images'] = raw_images
n_completions = body.get('n', 1) or 1
if not stream: if not stream:
prompt_arg = body[prompt_str] prompt_arg = body[prompt_str]
@ -685,7 +437,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
resp_list_data = [] resp_list_data = []
total_completion_token_count = 0 total_completion_token_count = 0
total_prompt_token_count = 0 total_prompt_token_count = 0
choice_index = 0
for idx, prompt in enumerate(prompt_arg, start=0): for idx, prompt in enumerate(prompt_arg, start=0):
if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int): if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
@ -700,17 +451,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
prompt = decode(prompt)[0] prompt = decode(prompt)[0]
prefix = prompt if echo else '' prefix = prompt if echo else ''
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
original_seed = generate_params.get('seed', -1)
for _n in range(n_completions):
# Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
if original_seed >= 0:
generate_params['seed'] = original_seed + _n
if logprob_proc:
logprob_proc.token_alternatives_history.clear()
# generate reply ####################################### # generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params}) debug_msg({'prompt': prompt, 'generate_params': generate_params})
@ -720,39 +460,28 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
for a in generator: for a in generator:
answer = a answer = a
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
completion_token_count = len(encode(answer)[0]) completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count total_completion_token_count += completion_token_count
stop_reason = "stop" stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens: if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length" stop_reason = "length"
if logprob_proc:
all_entries = []
for alt in logprob_proc.token_alternatives_history:
all_entries.extend(_dict_to_logprob_entries(alt))
completion_logprobs = format_completion_logprobs(all_entries)
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
raw = getattr(shared.model, 'last_completion_probabilities', None)
completion_logprobs = format_completion_logprobs(raw)
else:
completion_logprobs = None
respi = { respi = {
"index": choice_index, "index": idx,
"finish_reason": stop_reason, "finish_reason": stop_reason,
"text": prefix + answer + suffix, "text": prefix + answer + suffix,
"logprobs": completion_logprobs, "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
} }
resp_list_data.append(respi) resp_list_data.extend([respi])
choice_index += 1
resp = { resp = {
"id": cmpl_id, "id": cmpl_id,
"object": object_type, "object": object_type,
"created": created_time, "created": created_time,
"model": shared.model_name, "model": shared.model_name,
"system_fingerprint": None,
resp_list: resp_list_data, resp_list: resp_list_data,
"usage": { "usage": {
"prompt_tokens": total_prompt_token_count, "prompt_tokens": total_prompt_token_count,
@ -777,41 +506,24 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
prefix = prompt if echo else '' prefix = prompt if echo else ''
token_count = len(encode(prompt)[0]) token_count = len(encode(prompt)[0])
# Check if usage should be included in streaming chunks per OpenAI spec
stream_options = body.get('stream_options')
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
cmpl_logprobs_offset = [0] # mutable for closure access in streaming
def text_streaming_chunk(content): def text_streaming_chunk(content):
# begin streaming # begin streaming
if logprob_proc:
chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives))
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0])
chunk_logprobs = format_completion_logprobs(entries) if entries else None
else:
chunk_logprobs = None
chunk = { chunk = {
"id": cmpl_id, "id": cmpl_id,
"object": object_type, "object": object_type,
"created": created_time, "created": created_time,
"model": shared.model_name, "model": shared.model_name,
"system_fingerprint": None,
resp_list: [{ resp_list: [{
"index": 0, "index": 0,
"finish_reason": None, "finish_reason": None,
"text": content, "text": content,
"logprobs": chunk_logprobs, "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}], }],
} }
return chunk return chunk
chunk = text_streaming_chunk(prefix) yield text_streaming_chunk(prefix)
if include_usage:
chunk['usage'] = None
yield chunk
# generate reply ####################################### # generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params}) debug_msg({'prompt': prompt, 'generate_params': generate_params})
@ -831,8 +543,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
seen_content = answer seen_content = answer
chunk = text_streaming_chunk(new_content) chunk = text_streaming_chunk(new_content)
if include_usage:
chunk['usage'] = None
yield chunk yield chunk
completion_token_count = len(encode(answer)[0]) completion_token_count = len(encode(answer)[0])
@ -842,46 +552,32 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
chunk = text_streaming_chunk(suffix) chunk = text_streaming_chunk(suffix)
chunk[resp_list][0]["finish_reason"] = stop_reason chunk[resp_list][0]["finish_reason"] = stop_reason
usage = { chunk["usage"] = {
"prompt_tokens": token_count, "prompt_tokens": token_count,
"completion_tokens": completion_token_count, "completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count "total_tokens": token_count + completion_token_count
} }
if include_usage:
chunk['usage'] = None
yield chunk
# Separate usage-only chunk with choices: [] per OpenAI spec
yield {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
"system_fingerprint": None,
resp_list: [],
"usage": usage
}
else:
yield chunk yield chunk
def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict: def chat_completions(body: dict, is_legacy: bool = False) -> dict:
generator = chat_completions_common(body, is_legacy, stream=False, stop_event=stop_event) generator = chat_completions_common(body, is_legacy, stream=False)
return deque(generator, maxlen=1).pop() return deque(generator, maxlen=1).pop()
def stream_chat_completions(body: dict, is_legacy: bool = False, stop_event=None): def stream_chat_completions(body: dict, is_legacy: bool = False):
for resp in chat_completions_common(body, is_legacy, stream=True, stop_event=stop_event): for resp in chat_completions_common(body, is_legacy, stream=True):
yield resp yield resp
def completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict: def completions(body: dict, is_legacy: bool = False) -> dict:
generator = completions_common(body, is_legacy, stream=False, stop_event=stop_event) generator = completions_common(body, is_legacy, stream=False)
return deque(generator, maxlen=1).pop() return deque(generator, maxlen=1).pop()
def stream_completions(body: dict, is_legacy: bool = False, stop_event=None): def stream_completions(body: dict, is_legacy: bool = False):
for resp in completions_common(body, is_legacy, stream=True, stop_event=stop_event): for resp in completions_common(body, is_legacy, stream=True):
yield resp yield resp
@ -892,12 +588,6 @@ def validateTools(tools: list[dict]):
tool = tools[idx] tool = tools[idx]
try: try:
tool_definition = ToolDefinition(**tool) tool_definition = ToolDefinition(**tool)
# Backfill defaults so Jinja2 templates don't crash on missing fields
func = tool.get("function", {})
if "description" not in func:
func["description"] = ""
if "parameters" not in func:
func["parameters"] = {"type": "object", "properties": {}}
if valid_tools is None: if valid_tools is None:
valid_tools = [] valid_tools = []
valid_tools.append(tool) valid_tools.append(tool)

View file

@ -1,4 +1,4 @@
from modules import loaders, shared from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model from modules.models import load_model, unload_model
@ -20,14 +20,10 @@ def list_models():
def list_models_openai_format(): def list_models_openai_format():
"""Returns model list in OpenAI API format""" """Returns model list in OpenAI API format"""
if shared.model_name and shared.model_name != 'None': model_names = get_available_models()
data = [model_info_dict(shared.model_name)]
else:
data = []
return { return {
"object": "list", "object": "list",
"data": data "data": [model_info_dict(name) for name in model_names]
} }
@ -50,14 +46,9 @@ def _load_model(data):
update_model_parameters(model_settings) update_model_parameters(model_settings)
# Update shared.args with custom model loading settings # Update shared.args with custom model loading settings
# Security: only allow keys that correspond to model loading
# parameters exposed in the UI. Never allow security-sensitive
# flags like trust_remote_code or extra_flags to be set via the API.
blocked_keys = {'extra_flags'}
allowed_keys = set(loaders.list_model_elements()) - blocked_keys
if args: if args:
for k in args: for k in args:
if k in allowed_keys and hasattr(shared.args, k): if hasattr(shared.args, k):
setattr(shared.args, k, args[k]) setattr(shared.args, k, args[k])
shared.model, shared.tokenizer = load_model(model_name) shared.model, shared.tokenizer = load_model(model_name)

View file

@ -3,7 +3,6 @@ import json
import logging import logging
import os import os
import socket import socket
import threading
import traceback import traceback
from collections import deque from collections import deque
from threading import Thread from threading import Thread
@ -21,12 +20,11 @@ import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels import extensions.openai.models as OAImodels
from extensions.openai.tokens import token_count, token_decode, token_encode from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.errors import OpenAIError
from extensions.openai.utils import _start_cloudflared from extensions.openai.utils import _start_cloudflared
from modules import shared from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.models import unload_model from modules.models import unload_model
from modules.text_generation import stop_everything_event # used by /v1/internal/stop-generation from modules.text_generation import stop_everything_event
from .typing import ( from .typing import (
ChatCompletionRequest, ChatCompletionRequest,
@ -60,13 +58,8 @@ params = {
} }
async def _wait_for_disconnect(request: Request, stop_event: threading.Event): streaming_semaphore = asyncio.Semaphore(1)
"""Block until the client disconnects, then signal the stop_event.""" image_generation_semaphore = asyncio.Semaphore(1)
while True:
message = await request.receive()
if message["type"] == "http.disconnect":
stop_event.set()
return
def verify_api_key(authorization: str = Header(None)) -> None: def verify_api_key(authorization: str = Header(None)) -> None:
@ -95,20 +88,6 @@ app.add_middleware(
) )
@app.exception_handler(OpenAIError)
async def openai_error_handler(request: Request, exc: OpenAIError):
    """Translate a raised ``OpenAIError`` into a JSON error response.

    The HTTP status is taken from ``exc.code``. The body has a single
    ``"error"`` object carrying the exception message, an error type derived
    from the status (``"server_error"`` for codes >= 500, otherwise
    ``"invalid_request_error"``), the optional ``param`` attribute of the
    exception (``None`` when absent), and a ``"code"`` that is always ``None``.
    """
    if exc.code >= 500:
        category = "server_error"
    else:
        category = "invalid_request_error"

    error_body = {
        "message": exc.message,
        "type": category,
        # Not every OpenAIError carries a `param`; fall back to None.
        "param": getattr(exc, 'param', None),
        "code": None,
    }
    return JSONResponse(status_code=exc.code, content={"error": error_body})
@app.middleware("http") @app.middleware("http")
async def validate_host_header(request: Request, call_next): async def validate_host_header(request: Request, call_next):
# Be strict about only approving access to localhost by default # Be strict about only approving access to localhost by default
@ -134,44 +113,29 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
is_legacy = "/generate" in path is_legacy = "/generate" in path
if request_data.stream: if request_data.stream:
if (request_data.n or 1) > 1:
return JSONResponse(
status_code=400,
content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
)
stop_event = threading.Event()
async def generator(): async def generator():
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event) async with streaming_semaphore:
try: try:
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
async for resp in iterate_in_threadpool(response): async for resp in iterate_in_threadpool(response):
disconnected = await request.is_disconnected() disconnected = await request.is_disconnected()
if disconnected: if disconnected:
break break
yield {"data": json.dumps(resp)} yield {"data": json.dumps(resp)}
yield {"data": "[DONE]"}
finally: finally:
stop_event.set() stop_everything_event()
response.close() response.close()
return
return EventSourceResponse(generator(), sep="\n") # SSE streaming return EventSourceResponse(generator()) # SSE streaming
else: else:
stop_event = threading.Event()
monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
try:
response = await asyncio.to_thread( response = await asyncio.to_thread(
OAIcompletions.completions, OAIcompletions.completions,
to_dict(request_data), to_dict(request_data),
is_legacy=is_legacy, is_legacy=is_legacy
stop_event=stop_event
) )
finally:
stop_event.set()
monitor.cancel()
return JSONResponse(response) return JSONResponse(response)
@ -182,38 +146,29 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
is_legacy = "/generate" in path is_legacy = "/generate" in path
if request_data.stream: if request_data.stream:
stop_event = threading.Event()
async def generator(): async def generator():
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event) async with streaming_semaphore:
try: try:
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
async for resp in iterate_in_threadpool(response): async for resp in iterate_in_threadpool(response):
disconnected = await request.is_disconnected() disconnected = await request.is_disconnected()
if disconnected: if disconnected:
break break
yield {"data": json.dumps(resp)} yield {"data": json.dumps(resp)}
yield {"data": "[DONE]"}
finally: finally:
stop_event.set() stop_everything_event()
response.close() response.close()
return
return EventSourceResponse(generator(), sep="\n") # SSE streaming return EventSourceResponse(generator()) # SSE streaming
else: else:
stop_event = threading.Event()
monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
try:
response = await asyncio.to_thread( response = await asyncio.to_thread(
OAIcompletions.chat_completions, OAIcompletions.chat_completions,
to_dict(request_data), to_dict(request_data),
is_legacy=is_legacy, is_legacy=is_legacy
stop_event=stop_event
) )
finally:
stop_event.set()
monitor.cancel()
return JSONResponse(response) return JSONResponse(response)
@ -277,6 +232,7 @@ async def handle_audio_transcription(request: Request):
async def handle_image_generation(request_data: ImageGenerationRequest): async def handle_image_generation(request_data: ImageGenerationRequest):
import extensions.openai.images as OAIimages import extensions.openai.images as OAIimages
async with image_generation_semaphore:
response = await asyncio.to_thread(OAIimages.generations, request_data) response = await asyncio.to_thread(OAIimages.generations, request_data)
return JSONResponse(response) return JSONResponse(response)
@ -401,9 +357,9 @@ async def handle_load_model(request_data: LoadModelRequest):
try: try:
OAImodels._load_model(to_dict(request_data)) OAImodels._load_model(to_dict(request_data))
return JSONResponse(content="OK") return JSONResponse(content="OK")
except Exception: except:
traceback.print_exc() traceback.print_exc()
raise HTTPException(status_code=400, detail="Failed to load the model.") return HTTPException(status_code=400, detail="Failed to load the model.")
@app.post("/v1/internal/model/unload", dependencies=check_admin_key) @app.post("/v1/internal/model/unload", dependencies=check_admin_key)
@ -422,9 +378,9 @@ async def handle_load_loras(request_data: LoadLorasRequest):
try: try:
OAImodels.load_loras(request_data.lora_names) OAImodels.load_loras(request_data.lora_names)
return JSONResponse(content="OK") return JSONResponse(content="OK")
except Exception: except:
traceback.print_exc() traceback.print_exc()
raise HTTPException(status_code=400, detail="Failed to apply the LoRA(s).") return HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key) @app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
@ -458,9 +414,6 @@ def run_server():
# In the server configuration: # In the server configuration:
server_addrs = [] server_addrs = []
if shared.args.listen and shared.args.listen_host:
server_addrs.append(shared.args.listen_host)
else:
if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6): if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
server_addrs.append('[::]' if shared.args.listen else '[::1]') server_addrs.append('[::]' if shared.args.listen else '[::1]')
if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4): if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
@ -475,11 +428,11 @@ def run_server():
port, port,
shared.args.public_api_id, shared.args.public_api_id,
max_attempts=3, max_attempts=3,
on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}/v1\n') on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}\n')
) )
else: else:
url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://' url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
urls = [f'{url_proto}{addr}:{port}/v1' for addr in server_addrs] urls = [f'{url_proto}{addr}:{port}' for addr in server_addrs]
if len(urls) > 1: if len(urls) > 1:
logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n') logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
else: else:

View file

@ -1,61 +1,57 @@
import json import json
import time import time
from typing import Any, Dict, List, Optional from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator, validator from pydantic import BaseModel, Field, model_validator, validator
from modules import shared
class GenerationOptions(BaseModel): class GenerationOptions(BaseModel):
preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/user_data/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.") preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/user_data/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
dynatemp_low: float = shared.args.dynatemp_low dynatemp_low: float = 1
dynatemp_high: float = shared.args.dynatemp_high dynatemp_high: float = 1
dynatemp_exponent: float = shared.args.dynatemp_exponent dynatemp_exponent: float = 1
smoothing_factor: float = shared.args.smoothing_factor smoothing_factor: float = 0
smoothing_curve: float = shared.args.smoothing_curve smoothing_curve: float = 1
min_p: float = shared.args.min_p min_p: float = 0
top_k: int = shared.args.top_k top_k: int = 0
typical_p: float = shared.args.typical_p typical_p: float = 1
xtc_threshold: float = shared.args.xtc_threshold xtc_threshold: float = 0.1
xtc_probability: float = shared.args.xtc_probability xtc_probability: float = 0
epsilon_cutoff: float = shared.args.epsilon_cutoff epsilon_cutoff: float = 0
eta_cutoff: float = shared.args.eta_cutoff eta_cutoff: float = 0
tfs: float = shared.args.tfs tfs: float = 1
top_a: float = shared.args.top_a top_a: float = 0
top_n_sigma: float = shared.args.top_n_sigma top_n_sigma: float = 0
adaptive_target: float = shared.args.adaptive_target dry_multiplier: float = 0
adaptive_decay: float = shared.args.adaptive_decay dry_allowed_length: int = 2
dry_multiplier: float = shared.args.dry_multiplier dry_base: float = 1.75
dry_allowed_length: int = shared.args.dry_allowed_length repetition_penalty: float = 1
dry_base: float = shared.args.dry_base encoder_repetition_penalty: float = 1
repetition_penalty: float = shared.args.repetition_penalty no_repeat_ngram_size: int = 0
encoder_repetition_penalty: float = shared.args.encoder_repetition_penalty repetition_penalty_range: int = 1024
no_repeat_ngram_size: int = shared.args.no_repeat_ngram_size penalty_alpha: float = 0
repetition_penalty_range: int = shared.args.repetition_penalty_range guidance_scale: float = 1
penalty_alpha: float = shared.args.penalty_alpha mirostat_mode: int = 0
guidance_scale: float = shared.args.guidance_scale mirostat_tau: float = 5
mirostat_mode: int = shared.args.mirostat_mode mirostat_eta: float = 0.1
mirostat_tau: float = shared.args.mirostat_tau
mirostat_eta: float = shared.args.mirostat_eta
prompt_lookup_num_tokens: int = 0 prompt_lookup_num_tokens: int = 0
max_tokens_second: int = 0 max_tokens_second: int = 0
do_sample: bool = shared.args.do_sample do_sample: bool = True
dynamic_temperature: bool = shared.args.dynamic_temperature dynamic_temperature: bool = False
temperature_last: bool = shared.args.temperature_last temperature_last: bool = False
auto_max_new_tokens: bool = False auto_max_new_tokens: bool = False
ban_eos_token: bool = False ban_eos_token: bool = False
add_bos_token: bool = True add_bos_token: bool = True
enable_thinking: bool = shared.args.enable_thinking enable_thinking: bool = True
reasoning_effort: str = shared.args.reasoning_effort reasoning_effort: str = "medium"
skip_special_tokens: bool = True skip_special_tokens: bool = True
static_cache: bool = False static_cache: bool = False
truncation_length: int = 0 truncation_length: int = 0
seed: int = -1 seed: int = -1
sampler_priority: List[str] | str | None = Field(default=shared.args.sampler_priority, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].") sampler_priority: List[str] | str | None = Field(default=None, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
custom_token_bans: str = "" custom_token_bans: str = ""
negative_prompt: str = '' negative_prompt: str = ''
dry_sequence_breakers: str = shared.args.dry_sequence_breakers dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
grammar_string: str = "" grammar_string: str = ""
@ -65,20 +61,22 @@ class ToolDefinition(BaseModel):
class ToolFunction(BaseModel): class ToolFunction(BaseModel):
model_config = ConfigDict(extra='allow') description: str
description: Optional[str] = None
name: str name: str
parameters: Optional['ToolParameters'] = None parameters: 'ToolParameters'
class ToolParameters(BaseModel): class ToolParameters(BaseModel):
model_config = ConfigDict(extra='allow') properties: Optional[Dict[str, 'ToolProperty']] = None
properties: Optional[Dict[str, Any]] = None
required: Optional[list[str]] = None required: Optional[list[str]] = None
type: str type: str
description: Optional[str] = None description: Optional[str] = None
class ToolProperty(BaseModel):
    # Schema for one value of the `properties` mapping declared on
    # ToolParameters. Both fields are optional and default to None.
    description: Optional[str] = None
    # `type` is optional because we are faced with definitions that use
    # constructs like anyOf instead of a plain type, e.g.
    # {'type': 'function', 'function': {'name': 'git_create_branch', 'description': 'Creates a new branch from an optional base branch', 'parameters': {'type': 'object', 'properties': {'repo_path': {'title': 'Repo Path', 'type': 'string'}, 'branch_name': {'title': 'Branch Name', 'type': 'string'}, 'base_branch': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None, 'title': 'Base Branch'}}, 'required': ['repo_path', 'branch_name'], 'title': 'GitCreateBranch'}}}
    type: Optional[str] = None
class FunctionCall(BaseModel): class FunctionCall(BaseModel):
name: str name: str
@ -99,28 +97,23 @@ class ToolCall(BaseModel):
function: FunctionCall function: FunctionCall
class StreamOptions(BaseModel):
    # Optional streaming settings; used as the type of the `stream_options`
    # field on the completion request models in this file.
    # NOTE(review): presumably mirrors OpenAI's `stream_options.include_usage`
    # (emit token usage while streaming) - confirm against the completions code.
    include_usage: bool | None = False
class CompletionRequestParams(BaseModel): class CompletionRequestParams(BaseModel):
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.") model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.") prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.") messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.")
best_of: int | None = Field(default=1, description="Unused parameter.") best_of: int | None = Field(default=1, description="Unused parameter.")
echo: bool | None = False echo: bool | None = False
frequency_penalty: float | None = shared.args.frequency_penalty frequency_penalty: float | None = 0
logit_bias: dict | None = None logit_bias: dict | None = None
logprobs: int | None = None logprobs: int | None = None
max_tokens: int | None = 512 max_tokens: int | None = 512
n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.") n: int | None = Field(default=1, description="Unused parameter.")
presence_penalty: float | None = shared.args.presence_penalty presence_penalty: float | None = 0
stop: str | List[str] | None = None stop: str | List[str] | None = None
stream: bool | None = False stream: bool | None = False
stream_options: StreamOptions | None = None
suffix: str | None = None suffix: str | None = None
temperature: float | None = shared.args.temperature temperature: float | None = 1
top_p: float | None = shared.args.top_p top_p: float | None = 1
user: str | None = Field(default=None, description="Unused parameter.") user: str | None = Field(default=None, description="Unused parameter.")
@model_validator(mode='after') @model_validator(mode='after')
@ -146,31 +139,20 @@ class CompletionResponse(BaseModel):
class ChatCompletionRequestParams(BaseModel): class ChatCompletionRequestParams(BaseModel):
messages: List[dict] messages: List[dict]
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.") model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
frequency_penalty: float | None = shared.args.frequency_penalty frequency_penalty: float | None = 0
function_call: str | dict | None = Field(default=None, description="Unused parameter.") function_call: str | dict | None = Field(default=None, description="Unused parameter.")
functions: List[dict] | None = Field(default=None, description="Unused parameter.") functions: List[dict] | None = Field(default=None, description="Unused parameter.")
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.") tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
logit_bias: dict | None = None logit_bias: dict | None = None
logprobs: bool | None = None
top_logprobs: int | None = None
max_tokens: int | None = None max_tokens: int | None = None
max_completion_tokens: int | None = None
n: int | None = Field(default=1, description="Unused parameter.") n: int | None = Field(default=1, description="Unused parameter.")
presence_penalty: float | None = shared.args.presence_penalty presence_penalty: float | None = 0
stop: str | List[str] | None = None stop: str | List[str] | None = None
stream: bool | None = False stream: bool | None = False
stream_options: StreamOptions | None = None temperature: float | None = 1
temperature: float | None = shared.args.temperature top_p: float | None = 1
top_p: float | None = shared.args.top_p
user: str | None = Field(default=None, description="Unused parameter.") user: str | None = Field(default=None, description="Unused parameter.")
@model_validator(mode='after')
def resolve_max_tokens(self):
if self.max_tokens is None and self.max_completion_tokens is not None:
self.max_tokens = self.max_completion_tokens
return self
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.") mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.") instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
@ -244,11 +226,11 @@ class LogitsRequestParams(BaseModel):
prompt: str prompt: str
use_samplers: bool = False use_samplers: bool = False
top_logits: int | None = 50 top_logits: int | None = 50
frequency_penalty: float | None = shared.args.frequency_penalty frequency_penalty: float | None = 0
max_tokens: int | None = 512 max_tokens: int | None = 512
presence_penalty: float | None = shared.args.presence_penalty presence_penalty: float | None = 0
temperature: float | None = shared.args.temperature temperature: float | None = 1
top_p: float | None = shared.args.top_p top_p: float | None = 1
class LogitsRequest(GenerationOptions, LogitsRequestParams): class LogitsRequest(GenerationOptions, LogitsRequestParams):

View file

@ -1,5 +1,8 @@
import base64 import base64
import json
import os import os
import random
import re
import time import time
import traceback import traceback
from typing import Callable, Optional from typing import Callable, Optional
@ -52,3 +55,94 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
time.sleep(3) time.sleep(3)
raise Exception('Could not start cloudflared.') raise Exception('Could not start cloudflared.')
def getToolCallId() -> str:
    """Return a pseudo-random tool-call id of the form ``call_xxxxxxxx``.

    The suffix is eight characters, each drawn independently with
    ``random.choice`` from the lowercase ASCII letters and the digits.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
    # The alphabet is already lowercase, so no case normalisation is needed.
    suffix = "".join(random.choice(alphabet) for _ in range(8))
    return "call_" + suffix
def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
    """Normalise one parsed tool-call candidate into OpenAI tool-call shape.

    Accepted input shapes:
      * ``{"name": <str>, ...}``                 -> wrapped as the ``function`` object
      * ``{"function": <str>, ...}``             -> the string becomes ``name``, then wrapped
      * ``{"function": {"name": ..., ...}, ...}`` -> used as is

    Args:
        candidate_dict: A value decoded from model output. Anything that is
            not a ``dict`` is rejected. The dict may be modified in place.
        tool_names: Names of the tools the caller knows about.

    Returns:
        A dict with ``"type": "function"`` and a ``"function"`` object whose
        ``name`` is in ``tool_names`` (a ``parameters`` key is renamed to
        ``arguments`` when ``arguments`` is missing), or ``None`` if the
        candidate does not describe a call to a known tool.
    """
    # Decoded JSON may be a number, null, bool, string or list. The membership
    # tests below would raise TypeError on several of those, so reject early.
    if not isinstance(candidate_dict, dict):
        return None

    # Bare {"name": ..., "arguments": ...}: wrap it as the function object.
    if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
        candidate_dict = {"type": "function", "function": candidate_dict}

    # {"function": "<tool name>", ...}: move the name into place, then wrap.
    if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
        candidate_dict['name'] = candidate_dict['function']
        del candidate_dict['function']
        candidate_dict = {"type": "function", "function": candidate_dict}

    function = candidate_dict.get('function')
    if not isinstance(function, dict):
        return None

    # Only calls to tools the caller actually offered are accepted.
    if 'name' not in function or function['name'] not in tool_names:
        return None

    # Ensure the required property 'type' exists and has the right value.
    candidate_dict["type"] = "function"

    # Map property 'parameters' used by some older models to 'arguments'.
    if "arguments" not in function and "parameters" in function:
        function["arguments"] = function["parameters"]
        del function["parameters"]

    return candidate_dict
def _loadToolCallCandidates(text: str) -> list:
    """Decode ``text`` as JSON and return the result as a list.

    If the text contains several JSON objects separated only by line breaks
    (``}\\n{``), commas are inserted between them and, unless the text already
    starts with ``[``, the whole thing is wrapped in brackets before decoding.
    Invalid JSON yields an empty list.
    """
    adjacent_objects = r"\}\s*\n\s*\{"
    if re.search(adjacent_objects, text) is not None:
        text = re.sub(adjacent_objects, "},\n{", text)
        if not text.strip().startswith("["):
            text = "[" + text + "]"

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Ignore invalid JSON silently
        return []

    return parsed if isinstance(parsed, list) else [parsed]


def parseToolCall(answer: str, tool_names: list[str]):
    """Extract tool calls from a model answer.

    The answer is first searched for JSON wrapped in code fences or in
    XML-like tags (``<function>``, ``<tools>``, ``<tool_call>`` and similar).
    If that finds nothing, the whole answer is tried as plain JSON.

    Args:
        answer: Raw text produced by the model.
        tool_names: Names of the tools the caller knows about.

    Returns:
        A list of sanitized tool-call dicts, as returned by
        ``checkAndSanitizeToolCallCandidate``. It is empty when the answer is
        shorter than 10 characters or contains no call to a known tool.
    """
    matches = []

    # abort on very short answers to save computation cycles
    if len(answer) < 10:
        return matches

    def collect(text: str) -> None:
        for candidate_dict in _loadToolCallCandidates(text):
            # Valid JSON is not necessarily a JSON object: skip numbers,
            # strings, nulls and nested lists instead of passing them on.
            if not isinstance(candidate_dict, dict):
                continue
            checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
            if checked_candidate is not None:
                matches.append(checked_candidate)

    # JSON content wrapped in code fences or in <function>, <tools>,
    # <tool_call>, and other tags observed from various models
    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]

    for pattern in patterns:
        for match in re.finditer(pattern, answer, re.DOTALL):
            if match.group(2) is None:
                continue
            # remove backtick wraps if present
            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
            candidate = re.sub(r"```$", "", candidate.strip())
            # unwrap inner tags
            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
            collect(candidate)

    # last resort if nothing has been mapped: LLM might have produced plain
    # json tool call without xml-like tags
    if len(matches) == 0:
        collect(answer)

    return matches

View file

@ -264,7 +264,7 @@ def SD_api_address_update(address):
response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models') response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
response.raise_for_status() response.raise_for_status()
# r = response.json() # r = response.json()
except Exception: except:
msg = "❌ No SD API endpoint on:" msg = "❌ No SD API endpoint on:"
return gr.Textbox.update(label=msg) return gr.Textbox.update(label=msg)
@ -284,7 +284,7 @@ def get_checkpoints():
options_json = options.json() options_json = options.json()
params['sd_checkpoint'] = options_json['sd_model_checkpoint'] params['sd_checkpoint'] = options_json['sd_model_checkpoint']
params['checkpoint_list'] = [result["title"] for result in models.json()] params['checkpoint_list'] = [result["title"] for result in models.json()]
except Exception: except:
params['sd_checkpoint'] = "" params['sd_checkpoint'] = ""
params['checkpoint_list'] = [] params['checkpoint_list'] = []
@ -298,7 +298,7 @@ def load_checkpoint(checkpoint):
try: try:
requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload) requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload)
except Exception: except:
pass pass
@ -307,7 +307,7 @@ def get_samplers():
response = requests.get(url=f'{params["address"]}/sdapi/v1/samplers') response = requests.get(url=f'{params["address"]}/sdapi/v1/samplers')
response.raise_for_status() response.raise_for_status()
samplers = [x["name"] for x in response.json()] samplers = [x["name"] for x in response.json()]
except Exception: except:
samplers = [] samplers = []
return samplers return samplers

View file

@ -2,5 +2,5 @@ beautifulsoup4==4.12.2
chromadb==0.4.24 chromadb==0.4.24
pandas==2.0.3 pandas==2.0.3
posthog==2.4.2 posthog==2.4.2
sentence_transformers==3.3.1 sentence_transformers==2.2.2
lxml lxml

View file

@ -11,11 +11,7 @@ function copyToClipboard(element) {
const rawText = messageElement.getAttribute("data-raw"); const rawText = messageElement.getAttribute("data-raw");
if (!rawText) return; if (!rawText) return;
const copyPromise = navigator.clipboard && window.isSecureContext navigator.clipboard.writeText(rawText).then(function() {
? navigator.clipboard.writeText(rawText)
: fallbackCopyToClipboard(rawText);
copyPromise.then(function() {
const originalSvg = element.innerHTML; const originalSvg = element.innerHTML;
element.innerHTML = "<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"20\" height=\"20\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\" class=\"text-green-500 dark:text-green-400\"><path d=\"M5 12l5 5l10 -10\"></path></svg>"; element.innerHTML = "<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"20\" height=\"20\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\" class=\"text-green-500 dark:text-green-400\"><path d=\"M5 12l5 5l10 -10\"></path></svg>";
setTimeout(() => { setTimeout(() => {
@ -26,27 +22,6 @@ function copyToClipboard(element) {
}); });
} }
// Copy `text` to the clipboard without the async Clipboard API: a temporary
// off-screen <textarea> holding the text is attached to the page, selected,
// and copied with document.execCommand("copy"). The returned Promise resolves
// when the command reports success, rejects without a value when it reports
// failure, and rejects with the thrown error if the attempt throws.
function fallbackCopyToClipboard(text) {
  return new Promise((resolve, reject) => {
    const scratch = document.createElement("textarea");
    scratch.value = text;

    // Park the helper element far outside the viewport while it is attached.
    const offscreen = { position: "fixed", left: "-9999px", top: "-9999px" };
    for (const [prop, value] of Object.entries(offscreen)) {
      scratch.style[prop] = value;
    }

    document.body.appendChild(scratch);
    scratch.focus();
    scratch.select();

    try {
      const copied = document.execCommand("copy");
      document.body.removeChild(scratch);
      if (copied) {
        resolve();
      } else {
        reject();
      }
    } catch (err) {
      // Always detach the helper element, even when the copy attempt throws.
      document.body.removeChild(scratch);
      reject(err);
    }
  });
}
function branchHere(element) { function branchHere(element) {
if (!element) return; if (!element) return;
@ -269,49 +244,7 @@ function removeLastClick() {
document.getElementById("Remove-last").click(); document.getElementById("Remove-last").click();
} }
// Scroll the chat's scroll container to the bottom, unless window.isScrolled
// is set. The container is the third ancestor of the #chat element. Nothing
// happens when it is missing, has no overflow, or is already within one pixel
// of the bottom.
function autoScrollToBottom() {
  if (window.isScrolled) return;

  const scroller = document.getElementById("chat")?.parentNode?.parentNode?.parentNode;
  if (!scroller) return;

  const bottom = scroller.scrollHeight - scroller.clientHeight;
  const hasOverflow = bottom > 0;
  if (hasOverflow && scroller.scrollTop < bottom - 1) {
    scroller.scrollTop = bottom;
  }
}
// In instruct mode (#chat has data-mode="instruct"), set the bottom padding of
// the .messages container from the window height and the heights of its last
// two children. Does nothing in other modes, when fewer than two children
// exist, or while #chat has zero offsetHeight.
function updateInstructPadding() {
  const chat = document.getElementById("chat");
  if (!chat || chat.getAttribute("data-mode") !== "instruct") return;

  const messages = chat.querySelector(".messages");
  const last = messages?.lastElementChild;
  const secondToLast = last?.previousElementSibling;
  if (!last || !secondToLast || !(chat.offsetHeight > 0)) return;

  // Take the larger of the two candidate heights, subtract the last message's
  // own height, and never go below zero.
  const fixedCandidate = window.innerHeight - 128 - 84;
  const siblingCandidate = window.innerHeight - secondToLast.offsetHeight - 84;
  let padding = Math.max(0, Math.max(fixedCandidate, siblingCandidate) - last.offsetHeight);

  // Narrow windows get 32px less padding.
  if (window.innerWidth <= 924) {
    padding = Math.max(0, padding - 32);
  }

  messages.style.paddingBottom = `${padding}px`;
}
let pendingMorphdomData = null;
let morphdomRafId = null;
function handleMorphdomUpdate(data) { function handleMorphdomUpdate(data) {
pendingMorphdomData = data;
if (!morphdomRafId) {
morphdomRafId = requestAnimationFrame(() => {
morphdomRafId = null;
applyMorphdomUpdate(pendingMorphdomData);
pendingMorphdomData = null;
});
}
}
function applyMorphdomUpdate(data) {
// Determine target element and use it as query scope // Determine target element and use it as query scope
var target_element, target_html; var target_element, target_html;
if (data.last_message_only) { if (data.last_message_only) {
@ -325,22 +258,28 @@ function applyMorphdomUpdate(data) {
const queryScope = target_element; const queryScope = target_element;
// Track open blocks and store their scroll positions // Track open blocks
const openBlocks = new Set(); const openBlocks = new Set();
const scrollPositions = {};
queryScope.querySelectorAll(".thinking-block").forEach(block => { queryScope.querySelectorAll(".thinking-block").forEach(block => {
const blockId = block.getAttribute("data-block-id"); const blockId = block.getAttribute("data-block-id");
// If block exists and is open, add to open set
if (blockId && block.hasAttribute("open")) { if (blockId && block.hasAttribute("open")) {
openBlocks.add(blockId); openBlocks.add(blockId);
}
});
// Store scroll positions for any open blocks
const scrollPositions = {};
queryScope.querySelectorAll(".thinking-block[open]").forEach(block => {
const content = block.querySelector(".thinking-content"); const content = block.querySelector(".thinking-content");
if (content) { const blockId = block.getAttribute("data-block-id");
if (content && blockId) {
const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5; const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
scrollPositions[blockId] = { scrollPositions[blockId] = {
position: content.scrollTop, position: content.scrollTop,
isAtBottom: isAtBottom isAtBottom: isAtBottom
}; };
} }
}
}); });
morphdom( morphdom(
@ -349,8 +288,8 @@ function applyMorphdomUpdate(data) {
{ {
onBeforeElUpdated: function(fromEl, toEl) { onBeforeElUpdated: function(fromEl, toEl) {
// Preserve code highlighting // Preserve code highlighting
if (fromEl.tagName === "PRE") { if (fromEl.tagName === "PRE" && fromEl.querySelector("code[data-highlighted]")) {
const fromCode = fromEl.querySelector("code[data-highlighted]"); const fromCode = fromEl.querySelector("code");
const toCode = toEl.querySelector("code"); const toCode = toEl.querySelector("code");
if (fromCode && toCode && fromCode.textContent === toCode.textContent) { if (fromCode && toCode && fromCode.textContent === toCode.textContent) {
@ -395,23 +334,10 @@ function applyMorphdomUpdate(data) {
} }
); );
// Syntax highlighting and LaTeX
if (window.doSyntaxHighlighting) {
window.doSyntaxHighlighting();
}
// Auto-scroll runs both before and after padding update.
// Before: so content growth isn't hidden by padding absorption.
// After: so padding-added space is also scrolled into view.
autoScrollToBottom();
updateInstructPadding();
autoScrollToBottom();
// Add toggle listeners for new blocks // Add toggle listeners for new blocks
queryScope.querySelectorAll(".thinking-block").forEach(block => { queryScope.querySelectorAll(".thinking-block").forEach(block => {
if (!block._hasToggleListener) { if (!block._hasToggleListener) {
block.addEventListener("toggle", function(e) { block.addEventListener("toggle", function(e) {
const wasScrolled = window.isScrolled;
if (this.open) { if (this.open) {
const content = this.querySelector(".thinking-content"); const content = this.querySelector(".thinking-content");
if (content) { if (content) {
@ -420,14 +346,44 @@ function applyMorphdomUpdate(data) {
}, 0); }, 0);
} }
} }
autoScrollToBottom();
updateInstructPadding();
autoScrollToBottom();
// Restore scroll state so the browser's layout adjustment
// from the toggle doesn't disable auto-scroll
window.isScrolled = wasScrolled;
}); });
block._hasToggleListener = true; block._hasToggleListener = true;
} }
}); });
} }
// Force the dark theme once Gradio has finished writing its own inline styles.
// The observer listens for "style" attribute mutations anywhere under <html>
// and reacts only to those whose target is the <gradio-app> element.
const observer = new MutationObserver((records) => {
  for (const record of records) {
    const isGradioStyleChange =
      record.type === "attributes" &&
      record.target.tagName === "GRADIO-APP" &&
      record.attributeName === "style";

    if (!isGradioStyleChange) {
      continue;
    }

    // Gradio has applied its styles: switch to dark and stop listening
    document.body.classList.add("dark");
    observer.disconnect();
  }
});

// Begin watching style attribute changes across the whole document tree
observer.observe(document.documentElement, {
  attributes: true,
  subtree: true,
  attributeFilter: ["style"],
});
//------------------------------------------------
// Suppress "Attempted to select a non-interactive or hidden tab" warning
//------------------------------------------------
(function() {
  // Substring identifying the one warning that should be dropped
  const suppressedFragment = "Attempted to select a non-interactive or hidden tab";
  const forwardWarn = console.warn;

  console.warn = function(...args) {
    const first = args[0];
    const shouldDrop = typeof first === "string" && first.includes(suppressedFragment);

    // Everything except the suppressed message goes to the original console.warn
    if (!shouldDrop) {
      forwardWarn.apply(console, args);
    }
  };
})();

View file

@ -1,84 +1 @@
function fallbackCopyToClipboard(text) { class CopyButtonPlugin{constructor(options={}){self.hook=options.hook;self.callback=options.callback;self.lang=options.lang||document.documentElement.lang||"en"}"after:highlightElement"({el,text}){let button=Object.assign(document.createElement("button"),{innerHTML:locales[lang]?.[0]||"Copy",className:"hljs-copy-button"});button.dataset.copied=false;el.parentElement.classList.add("hljs-copy-wrapper");el.parentElement.appendChild(button);el.parentElement.style.setProperty("--hljs-theme-background",window.getComputedStyle(el).backgroundColor);button.onclick=function(){if(!navigator.clipboard)return;let newText=text;if(hook&&typeof hook==="function"){newText=hook(text,el)||text}navigator.clipboard.writeText(newText).then(function(){button.innerHTML=locales[lang]?.[1]||"Copied!";button.dataset.copied=true;let alert=Object.assign(document.createElement("div"),{role:"status",className:"hljs-copy-alert",innerHTML:locales[lang]?.[2]||"Copied to clipboard"});el.parentElement.appendChild(alert);setTimeout(()=>{button.innerHTML=locales[lang]?.[0]||"Copy";button.dataset.copied=false;el.parentElement.removeChild(alert);alert=null},2e3)}).then(function(){if(typeof callback==="function")return callback(newText,el)})}}}if(typeof module!="undefined"){module.exports=CopyButtonPlugin}const locales={en:["Copy","Copied!","Copied to clipboard"],es:["Copiar","¡Copiado!","Copiado al portapapeles"],fr:["Copier","Copié !","Copié dans le presse-papier"],de:["Kopieren","Kopiert!","In die Zwischenablage kopiert"],ja:["コピー","コピーしました!","クリップボードにコピーしました"],ko:["복사","복사됨!","클립보드에 복사됨"],ru:["Копировать","Скопировано!","Скопировано в буфер обмена"],zh:["复制","已复制!","已复制到剪贴板"],"zh-tw":["複製","已複製!","已複製到剪貼簿"]};
return new Promise((resolve, reject) => {
const textArea = document.createElement("textarea");
textArea.value = text;
textArea.style.position = "fixed";
textArea.style.left = "-9999px";
textArea.style.top = "-9999px";
document.body.appendChild(textArea);
textArea.focus();
textArea.select();
try {
const successful = document.execCommand("copy");
document.body.removeChild(textArea);
successful ? resolve() : reject();
} catch (err) {
document.body.removeChild(textArea);
reject(err);
}
});
}
/**
 * highlight.js plugin that adds a "Copy" button to every highlighted block.
 *
 * Options are stored on the plugin instance (not on the global object), so
 * several instances with different hooks/callbacks/languages can coexist and
 * no `hook`, `callback` or `lang` globals are created.
 */
class CopyButtonPlugin {
  /**
   * @param {Object} [options]
   * @param {Function} [options.hook] Called as hook(text, el); a truthy return
   *   value replaces the text that gets copied.
   * @param {Function} [options.callback] Called as callback(copiedText, el)
   *   after a successful copy.
   * @param {string} [options.lang] Key into `locales`; defaults to the
   *   document language, then "en".
   */
  constructor(options = {}) {
    this.hook = options.hook;
    this.callback = options.callback;
    this.lang = options.lang || document.documentElement.lang || "en";
  }

  /**
   * highlight.js "after:highlightElement" hook: attaches the copy button to
   * the parent of the highlighted element.
   */
  "after:highlightElement"({ el, text }) {
    // Capture instance state now; inside the click handler `this` is the button
    const { hook, callback, lang } = this;

    let button = Object.assign(document.createElement("button"), {
      innerHTML: locales[lang]?.[0] || "Copy",
      className: "hljs-copy-button",
    });
    button.dataset.copied = false;
    el.parentElement.classList.add("hljs-copy-wrapper");
    el.parentElement.appendChild(button);

    // Expose the theme background so the stylesheet can match the button to it
    el.parentElement.style.setProperty(
      "--hljs-theme-background",
      window.getComputedStyle(el).backgroundColor,
    );

    button.onclick = function () {
      let newText = text;
      if (typeof hook === "function") {
        newText = hook(text, el) || text;
      }

      // The async Clipboard API is only used in secure contexts;
      // otherwise fall back to the textarea/execCommand helper.
      const copyPromise =
        navigator.clipboard && window.isSecureContext
          ? navigator.clipboard.writeText(newText)
          : fallbackCopyToClipboard(newText);

      copyPromise
        .then(function () {
          button.innerHTML = locales[lang]?.[1] || "Copied!";
          button.dataset.copied = true;

          let alert = Object.assign(document.createElement("div"), {
            role: "status",
            className: "hljs-copy-alert",
            innerHTML: locales[lang]?.[2] || "Copied to clipboard",
          });
          el.parentElement.appendChild(alert);

          setTimeout(() => {
            button.innerHTML = locales[lang]?.[0] || "Copy";
            button.dataset.copied = false;
            // remove() is a no-op if the alert was already detached (e.g. the
            // message was re-rendered), where removeChild would throw.
            alert.remove();
            alert = null;
          }, 2e3);
        })
        .then(function () {
          if (typeof callback === "function") return callback(newText, el);
        })
        .catch(function (err) {
          // Report copy/callback failures instead of leaving the rejection unhandled
          console.error("hljs-copy: failed to copy to clipboard", err);
        });
    };
  }
}
if (typeof module != "undefined") {
module.exports = CopyButtonPlugin;
}
const locales = {
en: ["Copy", "Copied!", "Copied to clipboard"],
es: ["Copiar", "¡Copiado!", "Copiado al portapapeles"],
fr: ["Copier", "Copié !", "Copié dans le presse-papier"],
de: ["Kopieren", "Kopiert!", "In die Zwischenablage kopiert"],
ja: ["コピー", "コピーしました!", "クリップボードにコピーしました"],
ko: ["복사", "복사됨!", "클립보드에 복사됨"],
ru: ["Копировать", "Скопировано!", "Скопировано в буфер обмена"],
zh: ["复制", "已复制!", "已复制到剪贴板"],
"zh-tw": ["複製", "已複製!", "已複製到剪貼簿"],
};

View file

@ -1,184 +0,0 @@
! function(e, t) {
"object" == typeof exports && "object" == typeof module ? module.exports = t(require("katex")) : "function" == typeof define && define.amd ? define(["katex"], t) : "object" == typeof exports ? exports.renderMathInElement = t(require("katex")) : e.renderMathInElement = t(e.katex)
}("undefined" != typeof self ? self : this, (function(e) {
return function() {
"use strict";
var t = {
771: function(t) {
t.exports = e
}
},
n = {};
function r(e) {
var o = n[e];
if (void 0 !== o) return o.exports;
var i = n[e] = {
exports: {}
};
return t[e](i, i.exports, r), i.exports
}
r.n = function(e) {
var t = e && e.__esModule ? function() {
return e.default
} : function() {
return e
};
return r.d(t, {
a: t
}), t
}, r.d = function(e, t) {
for (var n in t) r.o(t, n) && !r.o(e, n) && Object.defineProperty(e, n, {
enumerable: !0,
get: t[n]
})
}, r.o = function(e, t) {
return Object.prototype.hasOwnProperty.call(e, t)
};
var o = {};
return function() {
r.d(o, {
default: function() {
return d
}
});
var e = r(771),
t = r.n(e);
const n = function(e, t, n) {
let r = n,
o = 0;
const i = e.length;
for (; r < t.length;) {
const n = t[r];
if (o <= 0 && t.slice(r, r + i) === e) return r;
"\\" === n ? r++ : "{" === n ? o++ : "}" === n && o--, r++
}
return -1
},
i = /^\\begin{/;
// Splits the string `e` into an array of { type: "text" } and { type: "math" }
// segments using the delimiter list `t` (objects with left/right/display).
// Returns the array `o`; any trailing unmatched text is appended as a final
// text segment.
var a = function(e, t) {
let r;
const o = [],
a = new RegExp("(" + t.map((e => e.left.replace(/[-/\\^$*+?.()|[\]{}]/g, "\\$&"))).join("|") + ")");
// Repeatedly locate the first left delimiter in the remaining text
for (; r = e.search(a), -1 !== r;) {
const charAfterOpen = e[r + 1];
// Extra heuristic for a single "$" (not "$$"): decide whether it opens
// inline math or is a literal dollar sign (e.g. a price).
if (e[r] == "$" && charAfterOpen != "$") {
const closeDollarIndex = e.indexOf('$', r + 1);
if (closeDollarIndex != -1) {
const charBeforeOpen = r > 0 ? e[r - 1] : '';
const charBeforeClose = r + 1 < closeDollarIndex ? e[closeDollarIndex - 1] : '';
// NOTE(review): charBeforeBeforeClose is computed but never read below
const charBeforeBeforeClose = r + 1 < closeDollarIndex ? e[closeDollarIndex - 2] : '';
const charAfterClose = closeDollarIndex + 1 < e.length ? e[closeDollarIndex + 1] : '';
// Treat the "$" as plain text when any of these holds:
//   - the char before the opening "$" is alphanumeric, "_", "$" or "-"
//   - the char before the closing "$" is a space
//   - the char after the opening "$" is a digit AND (the char after the
//     closing "$" is alphanumeric OR the char before it is "-")
// (&& binds tighter than ||, which is what groups the last case.)
if ((/[A-Za-z0-9_$-]/.test(charBeforeOpen)) || ((' ' == charBeforeClose) ||
/[0-9]/.test(charAfterOpen) &&
(/[A-Za-z0-9]/.test(charAfterClose)
|| '-' == charBeforeClose))) {
// Emit everything up to and including the "$" as text and rescan the rest
o.push({
type: "text",
data: e.slice(0, r + 1),
});
e = e.slice(r + 1); // now text starts after delimiter
continue;
}
}
}
// Flush the text that precedes the delimiter, so `e` now starts at it
r > 0 && (o.push({
type: "text",
data: e.slice(0, r)
}), e = e.slice(r));
// Identify which delimiter matched, then find its closing counterpart with
// the helper `n` defined earlier in this bundle; stop if there is none.
const a = t.findIndex((t => e.startsWith(t.left)));
if (r = n(t[a].right, e, t[a].left.length), -1 === r) break;
// `l` is the raw span including delimiters. `s` is what gets rendered: the
// whole span when it starts with "\begin{" (regex `i`), otherwise only the
// content between the delimiters.
const l = e.slice(0, r + t[a].right.length),
s = i.test(l) ? l : e.slice(t[a].left.length, r);
o.push({
type: "math",
data: s,
rawData: l,
display: t[a].display
}), e = e.slice(r + t[a].right.length)
}
// Whatever remains (including an unterminated delimiter) is plain text
return "" !== e && o.push({
type: "text",
data: e
}), o
};
const l = function(e, n) {
const r = a(e, n.delimiters);
if (1 === r.length && "text" === r[0].type) return null;
const o = document.createDocumentFragment();
for (let e = 0; e < r.length; e++)
if ("text" === r[e].type) o.appendChild(document.createTextNode(r[e].data));
else {
const i = document.createElement("span");
let a = r[e].data;
n.displayMode = r[e].display;
try {
n.preProcess && (a = n.preProcess(a)), t().render(a, i, n)
} catch (i) {
if (!(i instanceof t().ParseError)) throw i;
n.errorCallback("KaTeX auto-render: Failed to parse `" + r[e].data + "` with ", i), o.appendChild(document.createTextNode(r[e].rawData));
continue
}
o.appendChild(i)
}
return o
},
s = function(e, t) {
for (let n = 0; n < e.childNodes.length; n++) {
const r = e.childNodes[n];
if (3 === r.nodeType) {
let o = r.textContent,
i = r.nextSibling,
a = 0;
for (; i && i.nodeType === Node.TEXT_NODE;) o += i.textContent, i = i.nextSibling, a++;
const s = l(o, t);
if (s) {
for (let e = 0; e < a; e++) r.nextSibling.remove();
n += s.childNodes.length - 1, e.replaceChild(s, r)
} else n += a
} else if (1 === r.nodeType) {
const e = " " + r.className + " "; - 1 === t.ignoredTags.indexOf(r.nodeName.toLowerCase()) && t.ignoredClasses.every((t => -1 === e.indexOf(" " + t + " "))) && s(r, t)
}
}
};
var d = function(e, t) {
if (!e) throw new Error("No element provided to render");
const n = {};
for (const e in t) t.hasOwnProperty(e) && (n[e] = t[e]);
n.delimiters = n.delimiters || [{
left: "$$",
right: "$$",
display: !0
}, {
left: "\\(",
right: "\\)",
display: !1
}, {
left: "\\begin{equation}",
right: "\\end{equation}",
display: !0
}, {
left: "\\begin{align}",
right: "\\end{align}",
display: !0
}, {
left: "\\begin{alignat}",
right: "\\end{alignat}",
display: !0
}, {
left: "\\begin{gather}",
right: "\\end{gather}",
display: !0
}, {
left: "\\begin{CD}",
right: "\\end{CD}",
display: !0
}, {
left: "\\[",
right: "\\]",
display: !0
}], n.ignoredTags = n.ignoredTags || ["script", "noscript", "style", "textarea", "pre", "code", "option"], n.ignoredClasses = n.ignoredClasses || [], n.errorCallback = n.errorCallback || console.error, n.macros = n.macros || {}, s(e, n)
}
}(), o = o.default
}()
}));

1
js/katex/auto-render.min.js vendored Normal file
View file

@ -0,0 +1 @@
!function(e,t){"object"==typeof exports&&"object"==typeof module?module.exports=t(require("katex")):"function"==typeof define&&define.amd?define(["katex"],t):"object"==typeof exports?exports.renderMathInElement=t(require("katex")):e.renderMathInElement=t(e.katex)}("undefined"!=typeof self?self:this,(function(e){return function(){"use strict";var t={771:function(t){t.exports=e}},n={};function r(e){var o=n[e];if(void 0!==o)return o.exports;var i=n[e]={exports:{}};return t[e](i,i.exports,r),i.exports}r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,{a:t}),t},r.d=function(e,t){for(var n in t)r.o(t,n)&&!r.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)};var o={};return function(){r.d(o,{default:function(){return d}});var e=r(771),t=r.n(e);const n=function(e,t,n){let r=n,o=0;const i=e.length;for(;r<t.length;){const n=t[r];if(o<=0&&t.slice(r,r+i)===e)return r;"\\"===n?r++:"{"===n?o++:"}"===n&&o--,r++}return-1},i=/^\\begin{/;var a=function(e,t){let r;const o=[],a=new RegExp("("+t.map((e=>e.left.replace(/[-/\\^$*+?.()|[\]{}]/g,"\\$&"))).join("|")+")");for(;r=e.search(a),-1!==r;){r>0&&(o.push({type:"text",data:e.slice(0,r)}),e=e.slice(r));const a=t.findIndex((t=>e.startsWith(t.left)));if(r=n(t[a].right,e,t[a].left.length),-1===r)break;const l=e.slice(0,r+t[a].right.length),s=i.test(l)?l:e.slice(t[a].left.length,r);o.push({type:"math",data:s,rawData:l,display:t[a].display}),e=e.slice(r+t[a].right.length)}return""!==e&&o.push({type:"text",data:e}),o};const l=function(e,n){const r=a(e,n.delimiters);if(1===r.length&&"text"===r[0].type)return null;const o=document.createDocumentFragment();for(let e=0;e<r.length;e++)if("text"===r[e].type)o.appendChild(document.createTextNode(r[e].data));else{const i=document.createElement("span");let a=r[e].data;n.displayMode=r[e].display;try{n.preProcess&&(a=n.preProcess(a)),t().render(a,i,n)}catch(i){if(!(i 
instanceof t().ParseError))throw i;n.errorCallback("KaTeX auto-render: Failed to parse `"+r[e].data+"` with ",i),o.appendChild(document.createTextNode(r[e].rawData));continue}o.appendChild(i)}return o},s=function(e,t){for(let n=0;n<e.childNodes.length;n++){const r=e.childNodes[n];if(3===r.nodeType){let o=r.textContent,i=r.nextSibling,a=0;for(;i&&i.nodeType===Node.TEXT_NODE;)o+=i.textContent,i=i.nextSibling,a++;const s=l(o,t);if(s){for(let e=0;e<a;e++)r.nextSibling.remove();n+=s.childNodes.length-1,e.replaceChild(s,r)}else n+=a}else if(1===r.nodeType){const e=" "+r.className+" ";-1===t.ignoredTags.indexOf(r.nodeName.toLowerCase())&&t.ignoredClasses.every((t=>-1===e.indexOf(" "+t+" ")))&&s(r,t)}}};var d=function(e,t){if(!e)throw new Error("No element provided to render");const n={};for(const e in t)t.hasOwnProperty(e)&&(n[e]=t[e]);n.delimiters=n.delimiters||[{left:"$$",right:"$$",display:!0},{left:"\\(",right:"\\)",display:!1},{left:"\\begin{equation}",right:"\\end{equation}",display:!0},{left:"\\begin{align}",right:"\\end{align}",display:!0},{left:"\\begin{alignat}",right:"\\end{alignat}",display:!0},{left:"\\begin{gather}",right:"\\end{gather}",display:!0},{left:"\\begin{CD}",right:"\\end{CD}",display:!0},{left:"\\[",right:"\\]",display:!0}],n.ignoredTags=n.ignoredTags||["script","noscript","style","textarea","pre","code","option"],n.ignoredClasses=n.ignoredClasses||[],n.errorCallback=n.errorCallback||console.error,n.macros=n.macros||{},s(e,n)}}(),o=o.default}()}));

View file

@ -2,12 +2,6 @@
// Main // Main
// ------------------------------------------------ // ------------------------------------------------
// Sync highlight.js theme with the actual Gradio theme
var defined_hljs_css = document.body.classList.contains("dark") ? "file/css/highlightjs/github-dark.min.css" : "file/css/highlightjs/github.min.css";
if (document.getElementById("highlight-css").getAttribute("href") !== defined_hljs_css) {
document.getElementById("highlight-css").setAttribute("href", defined_hljs_css);
}
let main_parent = document.getElementById("chat-tab").parentNode; let main_parent = document.getElementById("chat-tab").parentNode;
let extensions = document.getElementById("extensions"); let extensions = document.getElementById("extensions");
@ -151,13 +145,10 @@ targetElement.classList.add("pretty_scrollbar");
targetElement.classList.add("chat-parent"); targetElement.classList.add("chat-parent");
window.isScrolled = false; window.isScrolled = false;
let scrollTimeout; let scrollTimeout;
let lastScrollTop = 0;
let lastScrollHeight = 0;
let lastClientHeight = 0;
targetElement.addEventListener("scroll", function() { targetElement.addEventListener("scroll", function() {
let diff = targetElement.scrollHeight - targetElement.clientHeight; let diff = targetElement.scrollHeight - targetElement.clientHeight;
let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0; let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff == 0;
// Add scrolling class to disable hover effects // Add scrolling class to disable hover effects
if (window.isScrolled || !isAtBottomNow) { if (window.isScrolled || !isAtBottomNow) {
@ -166,12 +157,9 @@ targetElement.addEventListener("scroll", function() {
if(isAtBottomNow) { if(isAtBottomNow) {
window.isScrolled = false; window.isScrolled = false;
} else if (targetElement.scrollTop < lastScrollTop && targetElement.scrollHeight >= lastScrollHeight && targetElement.clientHeight <= lastClientHeight) { } else {
window.isScrolled = true; window.isScrolled = true;
} }
lastScrollTop = targetElement.scrollTop;
lastScrollHeight = targetElement.scrollHeight;
lastClientHeight = targetElement.clientHeight;
// Clear previous timeout and set new one // Clear previous timeout and set new one
clearTimeout(scrollTimeout); clearTimeout(scrollTimeout);
@ -182,28 +170,61 @@ targetElement.addEventListener("scroll", function() {
}); });
// Create a MutationObserver instance // Create a MutationObserver instance
const observer = new MutationObserver(function() { const observer = new MutationObserver(function(mutations) {
// Check if this is just the scrolling class being toggled
const isScrollingClassOnly = mutations.every(mutation =>
mutation.type === "attributes" &&
mutation.attributeName === "class" &&
mutation.target === targetElement
);
if (targetElement.classList.contains("_generating")) { if (targetElement.classList.contains("_generating")) {
typing.parentNode.classList.add("visible-dots"); typing.parentNode.classList.add("visible-dots");
document.getElementById("stop").style.display = "flex"; document.getElementById("stop").style.display = "flex";
document.getElementById("Generate").style.display = "none"; document.getElementById("Generate").style.display = "none";
// If the user is near the bottom, ensure auto-scroll is enabled
// for the new reply. This catches cases where isScrolled was
// incorrectly set to true by layout shifts during page load, etc.
const diff = targetElement.scrollHeight - targetElement.clientHeight;
if (Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0) {
window.isScrolled = false;
}
} else { } else {
typing.parentNode.classList.remove("visible-dots"); typing.parentNode.classList.remove("visible-dots");
document.getElementById("stop").style.display = "none"; document.getElementById("stop").style.display = "none";
document.getElementById("Generate").style.display = "flex"; document.getElementById("Generate").style.display = "flex";
} }
doSyntaxHighlighting();
if (!window.isScrolled && !isScrollingClassOnly) {
const maxScroll = targetElement.scrollHeight - targetElement.clientHeight;
if (maxScroll > 0 && targetElement.scrollTop < maxScroll - 1) {
targetElement.scrollTop = maxScroll;
}
}
const chatElement = document.getElementById("chat");
if (chatElement && chatElement.getAttribute("data-mode") === "instruct") {
const messagesContainer = chatElement.querySelector(".messages");
const lastChild = messagesContainer?.lastElementChild;
const prevSibling = lastChild?.previousElementSibling;
if (lastChild && prevSibling) {
// Add padding to the messages container to create room for the last message.
// The purpose of this is to avoid constant scrolling during streaming in
// instruct mode.
let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight);
// Subtract header height when screen width is <= 924px
if (window.innerWidth <= 924) {
bufferHeight = Math.max(0, bufferHeight - 32);
}
messagesContainer.style.paddingBottom = `${bufferHeight}px`;
}
}
}); });
// Only watch for attribute changes on targetElement (e.g. _generating class) // Configure the observer to watch for changes in the subtree and attributes
const config = { const config = {
attributes: true childList: true,
subtree: true,
characterData: true,
attributeOldValue: true,
characterDataOldValue: true
}; };
// Start observing the target element // Start observing the target element
@ -222,10 +243,13 @@ function isElementVisibleOnScreen(element) {
); );
} }
window.doSyntaxHighlighting = function() { function doSyntaxHighlighting() {
const messageBodies = document.getElementById("chat").querySelectorAll(".message-body"); const messageBodies = document.getElementById("chat").querySelectorAll(".message-body");
if (messageBodies.length > 0) { if (messageBodies.length > 0) {
observer.disconnect();
try {
let hasSeenVisible = false; let hasSeenVisible = false;
// Go from last message to first // Go from last message to first
@ -250,7 +274,6 @@ window.doSyntaxHighlighting = function() {
renderMathInElement(container, { renderMathInElement(container, {
delimiters: [ delimiters: [
{ left: "$$", right: "$$", display: true }, { left: "$$", right: "$$", display: true },
{ left: "$", right: "$", display: false },
{ left: "\\(", right: "\\)", display: false }, { left: "\\(", right: "\\)", display: false },
{ left: "\\[", right: "\\]", display: true }, { left: "\\[", right: "\\]", display: true },
], ],
@ -263,35 +286,20 @@ window.doSyntaxHighlighting = function() {
break; break;
} }
} }
} finally {
observer.observe(targetElement, config);
}
} }
} }
const doSyntaxHighlighting = window.doSyntaxHighlighting;
//------------------------------------------------ //------------------------------------------------
// Add some scrollbars // Add some scrollbars
//------------------------------------------------ //------------------------------------------------
const scrollbarElements = document.querySelectorAll(".add_scrollbar textarea, .add_scrollbar .drag-drop-list"); const textareaElements = document.querySelectorAll(".add_scrollbar textarea");
for(i = 0; i < scrollbarElements.length; i++) { for(i = 0; i < textareaElements.length; i++) {
scrollbarElements[i].classList.remove("scroll-hide"); textareaElements[i].classList.remove("scroll-hide");
scrollbarElements[i].classList.add("pretty_scrollbar"); textareaElements[i].classList.add("pretty_scrollbar");
scrollbarElements[i].style.resize = "none"; textareaElements[i].style.resize = "none";
}
//------------------------------------------------
// Tools: inject "Refresh list" link into the label
//------------------------------------------------
const toolsTitle = document.querySelector("#tools-group > [data-testid='block-info']");
const toolsInfo = toolsTitle ? toolsTitle.nextElementSibling : null;
if (toolsInfo) {
const refreshLink = document.createElement("span");
refreshLink.textContent = " [Refresh list]";
refreshLink.className = "tools-refresh-link";
refreshLink.addEventListener("click", function(e) {
e.preventDefault();
document.querySelector("#tools-refresh-btn").click();
});
toolsInfo.appendChild(refreshLink);
} }
//------------------------------------------------ //------------------------------------------------
@ -552,38 +560,6 @@ document.querySelectorAll(".focus-on-chat-input").forEach(element => {
}); });
}); });
//------------------------------------------------
// "New chat" hover menu with incognito option
//------------------------------------------------
(function() {
const newChatBtn = document.getElementById("new-chat-btn");
const wrapper = document.createElement("div");
wrapper.id = "new-chat-wrapper";
newChatBtn.replaceWith(wrapper);
wrapper.appendChild(newChatBtn);
const arrow = document.createElement("span");
arrow.className = "new-chat-arrow";
arrow.textContent = "\u25BE";
const menu = document.createElement("div");
menu.className = "new-chat-menu";
const option = document.createElement("div");
option.className = "new-chat-menu-item";
option.textContent = "Incognito chat";
menu.appendChild(option);
arrow.appendChild(menu);
wrapper.appendChild(arrow);
option.addEventListener("click", function(e) {
e.stopPropagation();
document.querySelector("#incognito-chat-btn").click();
});
})();
//------------------------------------------------ //------------------------------------------------
// Fix a border around the "past chats" menu // Fix a border around the "past chats" menu
//------------------------------------------------ //------------------------------------------------
@ -1113,13 +1089,15 @@ document.fonts.addEventListener("loadingdone", (event) => {
const currentHeight = chatInputRow.offsetHeight; const currentHeight = chatInputRow.offsetHeight;
const heightDifference = currentHeight - originalHeight; const heightDifference = currentHeight - originalHeight;
chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`; chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`;
if (!window.isScrolled) {
chatParent.scrollTop = chatParent.scrollHeight - chatParent.clientHeight;
}
} }
// Watch for size changes that affect height // Watch for changes that might affect height
new ResizeObserver(updateMargin).observe(chatInputRow); const observer = new MutationObserver(updateMargin);
observer.observe(chatInputRow, {
childList: true,
subtree: true,
attributes: true
});
// Also listen for window resize // Also listen for window resize
window.addEventListener("resize", updateMargin); window.addEventListener("resize", updateMargin);

View file

@ -5,6 +5,9 @@ from modules.logging_colors import logger
def add_lora_to_model(lora_names): def add_lora_to_model(lora_names):
if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader in ['ExLlamav2', 'ExLlamav2_HF']:
add_lora_exllamav2(lora_names)
else:
add_lora_transformers(lora_names) add_lora_transformers(lora_names)
@ -16,6 +19,32 @@ def get_lora_path(lora_name):
return Path(f"{shared.args.lora_dir}/{lora_name}") return Path(f"{shared.args.lora_dir}/{lora_name}")
def add_lora_exllamav2(lora_names):
    """
    Replace the LoRAs attached to the loaded ExLlamaV2 model.

    Any LoRAs currently held in shared.model.loras are unloaded first. With a
    non-empty lora_names, each one is loaded from its directory and stored in
    shared.model.loras; with an empty list, shared.model.loras is reset to
    None. shared.lora_names is updated to match in both cases.
    """
    from exllamav2 import ExLlamaV2Lora

    # Unload whatever is attached right now
    current = shared.model.loras
    if isinstance(current, list):
        for loaded in current:
            loaded.unload()

    # Nothing requested: clear the state and stop
    if len(lora_names) == 0:
        shared.lora_names = []
        shared.model.loras = None
        return

    logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
    shared.model.loras = []
    for lora_name in lora_names:
        lora_path = get_lora_path(lora_name)

        # The two wrapper classes keep the underlying model under different attributes
        if shared.model.__class__.__name__ == 'Exllamav2Model':
            base_model = shared.model.model
        else:
            base_model = shared.model.ex_model

        shared.model.loras.append(ExLlamaV2Lora.from_directory(base_model, str(lora_path)))

    shared.lora_names = lora_names
def add_lora_transformers(lora_names): def add_lora_transformers(lora_names):
from peft import PeftModel from peft import PeftModel
@ -48,7 +77,9 @@ def add_lora_transformers(lora_names):
if len(lora_names) > 0: if len(lora_names) > 0:
params = {} params = {}
if not shared.args.cpu: if not shared.args.cpu:
if not shared.args.load_in_4bit and not shared.args.load_in_8bit: if shared.args.load_in_4bit or shared.args.load_in_8bit:
params['peft_type'] = shared.model.dtype
else:
params['dtype'] = shared.model.dtype params['dtype'] = shared.model.dtype
if hasattr(shared.model, "hf_device_map"): if hasattr(shared.model, "hf_device_map"):
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()} params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}

96
modules/block_requests.py Normal file
View file

@ -0,0 +1,96 @@
import builtins
import io
import re
import requests
from modules import shared, ui
from modules.logging_colors import logger
original_open = open
original_get = requests.get
original_print = print
class RequestBlocker:
    """
    Context manager that swaps requests.get for my_get while active and puts
    the function captured at import time (original_get) back on exit.
    """

    def __enter__(self):
        setattr(requests, 'get', my_get)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restored unconditionally; returns None so exceptions still propagate
        setattr(requests, 'get', original_get)
class OpenMonkeyPatch:
    """
    Context manager that replaces builtins.open and builtins.print with
    my_open and my_print while active, then restores the versions captured at
    import time (original_open / original_print) on exit.
    """

    def __enter__(self):
        setattr(builtins, 'open', my_open)
        setattr(builtins, 'print', my_print)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restored unconditionally; returns None so exceptions still propagate
        setattr(builtins, 'open', original_open)
        setattr(builtins, 'print', original_print)
def my_get(url, **kwargs):
    """
    Stand-in for requests.get that ignores `url` and sends the GET request to
    http://127.0.0.1/ instead, logging that a request was intercepted.
    """
    logger.info('Unwanted HTTP request redirected to localhost :)')

    # Same default as setdefault(): only applied when the caller gave no value
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True

    return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
def my_open(*args, **kwargs):
    """
    Replacement for builtins.open that rewrites index.html / share.html on read.

    For any other file the call is forwarded unchanged to original_open. For
    the two HTML files the content is read, the iframe-resizer script tag is
    removed, cdnjs.cloudflare.com is pointed at 127.0.0.1, local font/script/
    style tags are injected before </head>, and the prefers-color-scheme dark
    media query on body is turned into a body.dark rule.

    Returns an in-memory io.BytesIO when the underlying file yielded bytes,
    otherwise an io.StringIO.
    """
    # The path may be given positionally or as file=...; with neither, let the
    # real open() produce its usual error instead of failing on args[0].
    target = args[0] if args else kwargs.get('file')
    if target is None or not str(target).endswith(('index.html', 'share.html')):
        return original_open(*args, **kwargs)

    with original_open(*args, **kwargs) as f:
        file_contents = f.read()

    # Detect binary mode from what was actually read, so mode='rb' passed by
    # keyword (or spelled 'br') is handled the same as positional 'rb'.
    is_binary = isinstance(file_contents, bytes)
    if is_binary:
        file_contents = file_contents.decode('utf-8')

    file_contents = file_contents.replace('\t\t<script\n\t\t\tsrc="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"\n\t\t\tasync\n\t\t></script>', '')
    file_contents = file_contents.replace('cdnjs.cloudflare.com', '127.0.0.1')
    file_contents = file_contents.replace(
        '</head>',
        '\n <link rel="preload" href="file/css/Inter/Inter-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>'
        '\n <link rel="preload" href="file/css/Inter/Inter-Italic-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>'
        '\n <link rel="preload" href="file/css/NotoSans/NotoSans-Medium.woff2" as="font" type="font/woff2" crossorigin>'
        '\n <link rel="preload" href="file/css/NotoSans/NotoSans-MediumItalic.woff2" as="font" type="font/woff2" crossorigin>'
        '\n <link rel="preload" href="file/css/NotoSans/NotoSans-Bold.woff2" as="font" type="font/woff2" crossorigin>'
        '\n <script src="file/js/katex/katex.min.js"></script>'
        '\n <script src="file/js/katex/auto-render.min.js"></script>'
        '\n <script src="file/js/highlightjs/highlight.min.js"></script>'
        '\n <script src="file/js/highlightjs/highlightjs-copy.min.js"></script>'
        '\n <script src="file/js/morphdom/morphdom-umd.min.js"></script>'
        f'\n <link id="highlight-css" rel="stylesheet" href="file/css/highlightjs/{"github-dark" if shared.settings["dark_theme"] else "github"}.min.css">'
        '\n <script>hljs.addPlugin(new CopyButtonPlugin());</script>'
        f'\n <script>{ui.global_scope_js}</script>'
        '\n </head>'
    )

    # Make the dark body styles depend on the body.dark class rather than on
    # the browser's colour-scheme preference.
    file_contents = re.sub(
        r'@media \(prefers-color-scheme: dark\) \{\s*body \{([^}]*)\}\s*\}',
        r'body.dark {\1}',
        file_contents,
        flags=re.DOTALL
    )

    if is_binary:
        return io.BytesIO(file_contents.encode('utf-8'))

    return io.StringIO(file_contents)
def my_print(*args, **kwargs):
    """
    Replacement for builtins.print that filters Gradio's startup messages.

    - A message containing 'To create a public link, set `share=True`' is
      dropped entirely.
    - A message containing 'Running on local URL' is stripped and wrapped in
      blank lines before being printed.
    - Everything else is forwarded unchanged to original_print.

    Only a string first argument is inspected, so calls such as print(5) or
    print(some_object) no longer raise TypeError from the substring test.
    """
    message = args[0] if args and isinstance(args[0], str) else None

    if message is not None:
        if 'To create a public link, set `share=True`' in message:
            return

        if 'Running on local URL' in message:
            args = (f"\n{message.strip()}\n",) + tuple(args[1:])

    original_print(*args, **kwargs)

View file

@ -37,7 +37,7 @@ class Iteratorize:
ret = self.mfunc(callback=_callback, *args, **self.kwargs) ret = self.mfunc(callback=_callback, *args, **self.kwargs)
except StopNowException: except StopNowException:
pass pass
except Exception: except:
traceback.print_exc() traceback.print_exc()
pass pass

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,74 @@
def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
    '''
    Build the DeepSpeed ZeRO stage 3 configuration dictionary.
    https://huggingface.co/docs/transformers/main_classes/deepspeed

    ds_bf16 enables the "bf16" section and disables "fp16" (and vice versa).
    train_batch_size is copied into "train_batch_size".
    nvme_offload_dir, when truthy, selects NVMe parameter offload at that
    path and adds an "aio" section; otherwise parameters are offloaded to CPU.

    Returns the configuration as a plain dict.
    '''
    use_nvme = bool(nvme_offload_dir)

    # Where ZeRO-3 parks the parameters.
    if use_nvme:
        offload_param = {
            "device": "nvme",
            "nvme_path": nvme_offload_dir,
            "pin_memory": True,
            "buffer_count": 5,
            "buffer_size": 1e9,
            "max_in_cpu": 1e9
        }
    else:
        offload_param = {
            "device": "cpu",
            "pin_memory": True
        }

    zero_optimization = {
        "stage": 3,
        "offload_param": offload_param,
        "overlap_comm": True,
    }

    # The two variants list these keys in a different order; keep each order
    # so the serialized configuration is unchanged.
    if use_nvme:
        zero_optimization["reduce_bucket_size"] = "auto"
        zero_optimization["contiguous_gradients"] = True
        zero_optimization["sub_group_size"] = 1e8
    else:
        zero_optimization["contiguous_gradients"] = True
        zero_optimization["reduce_bucket_size"] = "auto"

    for stage3_key in (
        "stage3_prefetch_bucket_size",
        "stage3_param_persistence_threshold",
        "stage3_max_live_parameters",
        "stage3_max_reuse_distance",
    ):
        zero_optimization[stage3_key] = "auto"

    ds_config = {
        "fp16": {
            "enabled": not ds_bf16,
        },
        "bf16": {
            "enabled": ds_bf16,
        },
        "zero_optimization": zero_optimization,
    }

    # Async I/O settings are only emitted for the NVMe variant.
    if use_nvme:
        ds_config["aio"] = {
            "block_size": 262144,
            "queue_depth": 32,
            "thread_count": 1,
            "single_submit": False,
            "overlap_events": True
        }

    ds_config["steps_per_print"] = 2000
    ds_config["train_batch_size"] = train_batch_size
    ds_config["train_micro_batch_size_per_gpu"] = 1
    ds_config["wall_clock_breakdown"] = False

    return ds_config

View file

@ -12,8 +12,8 @@ from modules.text_generation import encode
def load_past_evaluations(): def load_past_evaluations():
if (shared.user_data_dir / 'logs' / 'evaluations.csv').exists(): if Path('user_data/logs/evaluations.csv').exists():
df = pd.read_csv(shared.user_data_dir / 'logs' / 'evaluations.csv', dtype=str) df = pd.read_csv(Path('user_data/logs/evaluations.csv'), dtype=str)
df['Perplexity'] = pd.to_numeric(df['Perplexity']) df['Perplexity'] = pd.to_numeric(df['Perplexity'])
return df return df
else: else:
@ -26,7 +26,7 @@ past_evaluations = load_past_evaluations()
def save_past_evaluations(df): def save_past_evaluations(df):
global past_evaluations global past_evaluations
past_evaluations = df past_evaluations = df
filepath = shared.user_data_dir / 'logs' / 'evaluations.csv' filepath = Path('user_data/logs/evaluations.csv')
filepath.parent.mkdir(parents=True, exist_ok=True) filepath.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(filepath, index=False) df.to_csv(filepath, index=False)
@ -46,6 +46,10 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
logger.error("Perplexity evaluation is not implemented for the llama.cpp loader.") logger.error("Perplexity evaluation is not implemented for the llama.cpp loader.")
raise ValueError raise ValueError
if shared.args.loader == "ExLlamav2":
logger.error("ExLlamav2_HF is required for perplexity evaluation with EXL2 models. Please reload the model with ExLlamav2_HF instead of ExLlamav2.")
raise ValueError
if not shared.args.no_use_fast: if not shared.args.no_use_fast:
logger.warning("--no_use_fast is not set. If tokenizing the input dataset takes a long time, try reloading the model with that option set/checked.") logger.warning("--no_use_fast is not set. If tokenizing the input dataset takes a long time, try reloading the model with that option set/checked.")
@ -65,7 +69,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
data = load_dataset('ptb_text_only', 'penn_treebank', split='test') data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
text = " ".join(data['sentence']) text = " ".join(data['sentence'])
else: else:
with open(shared.user_data_dir / 'training' / 'datasets' / f'{input_dataset}.txt', 'r', encoding='utf-8') as f: with open(Path(f'user_data/training/datasets/{input_dataset}.txt'), 'r', encoding='utf-8') as f:
text = f.read() text = f.read()
for model in models: for model in models:
@ -82,7 +86,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
update_model_parameters(model_settings) # hijacking the command-line arguments update_model_parameters(model_settings) # hijacking the command-line arguments
unload_model() unload_model()
shared.model, shared.tokenizer = load_model(model) shared.model, shared.tokenizer = load_model(model)
except Exception: except:
cumulative_log += f"Failed to load `{model}`. Moving on.\n\n" cumulative_log += f"Failed to load `{model}`. Moving on.\n\n"
yield cumulative_log yield cumulative_log
continue continue

247
modules/exllamav2.py Normal file
View file

@ -0,0 +1,247 @@
import json
import traceback
from pathlib import Path
import torch
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Cache_Q4,
ExLlamaV2Cache_Q6,
ExLlamaV2Cache_Q8,
ExLlamaV2Cache_TP,
ExLlamaV2Config,
ExLlamaV2Tokenizer
)
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
try:
import flash_attn
except Exception:
logger.warning('Failed to load flash-attention due to the following error:\n')
traceback.print_exc()
class Exllamav2Model:
    """
    Wrapper around an ExLlamaV2 model, its KV cache, tokenizer and streaming
    generator, configured from ``shared.args``.

    Instances are created through ``from_pretrained``; ``__init__`` does no
    work on its own.
    """

    def __init__(self):
        pass

    @staticmethod
    def _resolve_cache_class(kv_cache_type):
        """Map a lower-cased cache type name to its ExLlamaV2 cache class.

        Raises:
            ValueError: if ``kv_cache_type`` is not one of the known names.
        """
        cache_classes = {
            'fp16': ExLlamaV2Cache,
            'fp8': ExLlamaV2Cache_8bit,
            'q8': ExLlamaV2Cache_Q8,
            'q6': ExLlamaV2Cache_Q6,
            'q4': ExLlamaV2Cache_Q4,
        }
        if kv_cache_type not in cache_classes:
            # Report the requested name. The previous message interpolated a
            # variable that was still unbound on this path.
            raise ValueError(f"Invalid cache type for ExLlamaV2: {kv_cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")

        return cache_classes[kv_cache_type]

    @classmethod
    def from_pretrained(cls, path_to_model):
        """Load the model found at ``shared.args.model_dir / path_to_model``.

        The loading strategy (tensor parallel, manual GPU split or autosplit),
        the KV cache type and the optional draft model for speculative
        decoding are all taken from ``shared.args``.

        Returns:
            A ``(model, tokenizer)`` pair. Both elements are the same
            ``Exllamav2Model`` instance, since it implements ``encode`` and
            ``decode`` itself.

        Raises:
            ValueError: if ``shared.args.cache_type`` is not a supported name.
        """
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)

        config = ExLlamaV2Config()
        config.model_dir = str(path_to_model)
        config.prepare()

        # Overrides are applied after prepare().
        config.max_seq_len = shared.args.ctx_size
        config.scale_pos_emb = shared.args.compress_pos_emb
        config.scale_alpha_value = shared.args.alpha_value
        config.no_flash_attn = shared.args.no_flash_attn
        config.no_xformers = shared.args.no_xformers
        config.no_sdpa = shared.args.no_sdpa
        config.num_experts_per_token = int(shared.args.num_experts_per_token)

        model = ExLlamaV2(config)

        split = None
        if shared.args.gpu_split:
            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]

        # With autosplit (and no TP) the weights are loaded further below,
        # once the lazy cache exists.
        if shared.args.enable_tp:
            model.load_tp(split)
        elif not shared.args.autosplit:
            model.load(split)

        # Determine the correct cache type
        cache_type = cls._resolve_cache_class(shared.args.cache_type.lower())

        # Use TP if specified
        if shared.args.enable_tp:
            cache = ExLlamaV2Cache_TP(model, base=cache_type)
        else:
            cache = cache_type(model, lazy=shared.args.autosplit)

        if shared.args.autosplit and not shared.args.enable_tp:
            model.load_autosplit(cache)

        tokenizer = ExLlamaV2Tokenizer(config)

        # Initialize draft model for speculative decoding
        draft_model = None
        draft_cache = None

        if shared.args.model_draft and shared.args.model_draft.lower() not in ["none", ""]:
            logger.info(f"Loading draft model for speculative decoding: {shared.args.model_draft}")

            # Try the value as a path first, then relative to the model directory.
            draft_path = Path(shared.args.model_draft)
            if not draft_path.exists():
                draft_path = Path(f'{shared.args.model_dir}') / Path(shared.args.model_draft)

            draft_config = ExLlamaV2Config()
            draft_config.model_dir = str(draft_path)
            draft_config.prepare()
            draft_config.arch_compat_overrides()

            # Set context size for draft model
            if shared.args.ctx_size_draft > 0:
                draft_config.max_seq_len = shared.args.ctx_size_draft
            else:
                draft_config.max_seq_len = config.max_seq_len

            draft_model = ExLlamaV2(draft_config)
            draft_cache = cache_type(draft_model, lazy=True)
            draft_model.load_autosplit(draft_cache)

            logger.info(f"Draft model loaded successfully with max_draft={shared.args.draft_max}")

        generator = ExLlamaV2StreamingGenerator(
            model,
            cache,
            tokenizer,
            draft_model=draft_model,
            draft_cache=draft_cache,
            num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0
        )

        result = cls()
        result.model = model
        result.cache = cache
        result.tokenizer = tokenizer
        result.generator = generator
        result.loras = None
        result.draft_model = draft_model
        result.draft_cache = draft_cache
        return result, result

    def encode(self, string, **kwargs):
        """Tokenize ``string`` with special tokens encoded.

        ``add_bos`` (default True) is taken from ``kwargs``; the remaining
        keyword arguments are forwarded to the tokenizer.
        """
        add_bos = kwargs.pop('add_bos', True)
        return self.tokenizer.encode(string, add_bos=add_bos, encode_special_tokens=True, **kwargs)

    def decode(self, ids, **kwargs):
        """Decode token ids to text, keeping special tokens.

        A plain list is wrapped into a one-row tensor and a single-element
        tensor is reshaped to one row before decoding. Extra keyword
        arguments are accepted but not used.
        """
        if isinstance(ids, list):
            ids = torch.tensor([ids])
        elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
            ids = ids.view(1, -1)

        return self.tokenizer.decode(ids, decode_special_tokens=True)[0]

    def get_logits(self, token_ids, **kwargs):
        """Return the logits for the last position of ``token_ids`` as a float CPU tensor.

        The cache is rewound to position 0. All tokens but the last are run
        with ``preprocess_only=True``, then the last token is forwarded with
        ``kwargs`` passed through.
        """
        self.cache.current_seq_len = 0
        if token_ids.shape[-1] > 1:
            self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)

        return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()

    def generate_with_streaming(self, prompt, state):
        """Generate a reply for ``prompt``, yielding the accumulated text after each chunk.

        Sampling parameters are read from the ``state`` dict. Generation stops
        on EOS, when ``shared.stop_everything`` is set, or after the token
        budget is spent.
        """
        settings = ExLlamaV2Sampler.Settings()

        settings.token_repetition_penalty = state['repetition_penalty']
        settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']

        settings.token_frequency_penalty = state['frequency_penalty']
        settings.token_presence_penalty = state['presence_penalty']

        settings.temperature = state['temperature']
        settings.smoothing_factor = state['smoothing_factor']
        settings.min_temp = state['dynatemp_low'] if state['dynamic_temperature'] else 0
        settings.max_temp = state['dynatemp_high'] if state['dynamic_temperature'] else 0
        settings.temp_exponent = state['dynatemp_exponent']
        settings.top_k = state['top_k']
        settings.top_p = state['top_p']
        settings.top_a = state['top_a']
        settings.min_p = state['min_p']
        settings.tfs = state['tfs']
        settings.typical = state['typical_p']

        settings.temperature_last = state['temperature_last']

        settings.mirostat = state['mirostat_mode'] == 2
        settings.mirostat_tau = state['mirostat_tau']
        settings.mirostat_eta = state['mirostat_eta']

        if state['ban_eos_token']:
            settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])

        if state['custom_token_bans']:
            # Blank entries (e.g. from a trailing comma) are skipped rather
            # than crashing int().
            to_ban = [int(x) for x in state['custom_token_bans'].split(',') if x.strip()]
            if len(to_ban) > 0:
                settings.disallow_tokens(self.tokenizer, to_ban)

        settings.dry_allowed_length = state['dry_allowed_length']
        settings.dry_base = state['dry_base']
        settings.dry_multiplier = state['dry_multiplier']

        # Dry sequence breakers processing
        if state['dry_multiplier'] > 0 and state['dry_sequence_breakers']:
            dry_sequence_breakers = state['dry_sequence_breakers']

            # Support both JSON array notation and comma-separated strings.
            if not dry_sequence_breakers.startswith("["):
                dry_sequence_breakers = "[" + dry_sequence_breakers + "]"

            sequence_breaker_strings = json.loads(dry_sequence_breakers)
            # Prefix with 'a' to get the correct encoding of the token at the end of a text.
            sequence_breakers = {
                self.encode(f"a{s}")[0, -1].item() for s in sequence_breaker_strings
            }

            settings.dry_sequence_breakers = sequence_breakers

        settings.xtc_probability = state['xtc_probability']
        settings.xtc_threshold = state['xtc_threshold']

        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
        # Keep only the most recent tokens that fit the prompt budget.
        ids = ids[:, -get_max_prompt_length(state):]

        if state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - ids.shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']

        # Reset speculative decoding stats if using a draft model
        if hasattr(self, 'draft_model') and self.draft_model is not None:
            self.generator.reset_sd_stats()

        self.generator.begin_stream(ids, settings, loras=self.loras)

        decoded_text = ''
        for _ in range(max_new_tokens):
            chunk, eos, _tokens = self.generator.stream()
            if eos or shared.stop_everything:
                break

            decoded_text += chunk
            yield decoded_text

        # Log speculative decoding stats if using draft model
        if hasattr(self, 'draft_model') and self.draft_model is not None:
            _efficiency, _accuracy, _total_tokens, total_draft_tokens, accepted_draft_tokens = self.generator.get_sd_stats()
            logger.info(f"Speculative decoding: accepted={accepted_draft_tokens}/{total_draft_tokens} tokens")

    def generate(self, prompt, state):
        """Run ``generate_with_streaming`` to completion and return the final text ('' if nothing was produced)."""
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output

203
modules/exllamav2_hf.py Normal file
View file

@ -0,0 +1,203 @@
import os
import traceback
from pathlib import Path
from typing import Any, Dict, Optional, Union
import torch
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Cache_Q4,
ExLlamaV2Cache_Q6,
ExLlamaV2Cache_Q8,
ExLlamaV2Cache_TP,
ExLlamaV2Config
)
from torch.nn import CrossEntropyLoss
from transformers import (
GenerationConfig,
GenerationMixin,
PretrainedConfig,
PreTrainedModel
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from modules import shared
from modules.logging_colors import logger
try:
import flash_attn
except Exception:
logger.warning('Failed to load flash-attention due to the following error:\n')
traceback.print_exc()
class Exllamav2HF(PreTrainedModel, GenerationMixin):
    """
    Adapter exposing an ExLlamaV2 model through the Hugging Face
    ``PreTrainedModel`` / ``GenerationMixin`` interface.

    The positive prompt uses ``self.ex_cache``. When ``shared.args.cfg_cache``
    is set, a second cache is kept for the negative (CFG) prompt.
    """

    def __init__(self, config: ExLlamaV2Config):
        super().__init__(PretrainedConfig.from_pretrained(config.model_dir))
        self.ex_config = config
        self.loras = None
        self.generation_config = GenerationConfig()

        self.ex_model = ExLlamaV2(config)

        gpu_split = shared.args.gpu_split
        split = [float(alloc) for alloc in gpu_split.split(",")] if gpu_split else None

        # With autosplit (and no TP) loading is deferred until the lazy cache exists.
        if shared.args.enable_tp:
            self.ex_model.load_tp(split)
        elif not shared.args.autosplit:
            self.ex_model.load(split)

        # Determine the correct cache type
        kv_cache_type = shared.args.cache_type.lower()
        known_cache_types = {
            'fp16': ExLlamaV2Cache,
            'fp8': ExLlamaV2Cache_8bit,
            'q8': ExLlamaV2Cache_Q8,
            'q6': ExLlamaV2Cache_Q6,
            'q4': ExLlamaV2Cache_Q4,
        }
        if kv_cache_type not in known_cache_types:
            raise ValueError(f"Invalid cache type for ExLlamaV2: {kv_cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")

        cache_type = known_cache_types[kv_cache_type]

        self.ex_cache = self._new_cache(cache_type)
        if shared.args.autosplit and not shared.args.enable_tp:
            self.ex_model.load_autosplit(self.ex_cache)

        self.past_seq = None

        # Separate cache and history for the negative prompt, only when requested.
        if shared.args.cfg_cache:
            self.ex_cache_negative = self._new_cache(cache_type)
            self.past_seq_negative = None

    def _new_cache(self, cache_type):
        """Create a cache for ``self.ex_model``: tensor-parallel if enabled, otherwise a plain ``cache_type``."""
        if shared.args.enable_tp:
            return ExLlamaV2Cache_TP(self.ex_model, base=cache_type)

        return cache_type(self.ex_model, lazy=shared.args.autosplit)

    def _validate_model_class(self):
        # Intentionally a no-op override.
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        # Intentionally a no-op override.
        pass

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Pass ``input_ids`` and all keyword arguments through unchanged."""
        return {'input_ids': input_ids, **kwargs}

    @property
    def device(self) -> torch.device:
        return torch.device(0)

    def __call__(self, *args, **kwargs):
        """Run one forward step and return a ``CausalLMOutputWithPast``.

        A positional argument marks a negative-prompt (CFG) call, which uses
        the negative cache and requires ``shared.args.cfg_cache``; without it
        an error is logged and ``None`` is returned. Otherwise the ids come
        from ``kwargs['input_ids']``.

        Without ``labels``, the longest prefix shared with the previous call is
        kept in the cache and only the remaining tokens are processed. With
        ``labels``, the full sequence is evaluated from scratch and a
        cross-entropy loss is computed.
        """
        use_cache = kwargs.get('use_cache', True)
        labels = kwargs.get('labels', None)
        past_key_values = kwargs.get('past_key_values', None)

        is_negative = len(args) > 0
        if is_negative:
            if not shared.args.cfg_cache:
                logger.error("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")
                return

            input_ids = args[0]
            past_seq, ex_cache = self.past_seq_negative, self.ex_cache_negative
        else:
            input_ids = kwargs['input_ids']
            past_seq, ex_cache = self.past_seq, self.ex_cache

        seq = input_ids[0].tolist()
        if is_negative and past_key_values is not None:
            seq = past_key_values + seq

        seq_tensor = torch.tensor(seq)

        # Make the forward call
        if labels is not None:
            ex_cache.current_seq_len = 0
            logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()
        else:
            cache_reused = False
            if past_seq is not None:
                overlap = min(past_seq.shape[0], seq_tensor.shape[0])
                mismatches = torch.nonzero(~torch.eq(past_seq[:overlap], seq_tensor[:overlap]))
                longest_prefix = mismatches[0].item() if len(mismatches) > 0 else overlap

                if longest_prefix > 0:
                    cache_reused = True
                    ex_cache.current_seq_len = longest_prefix
                    unseen = len(seq_tensor) - longest_prefix
                    if unseen > 1:
                        self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
                    elif unseen == 0:
                        # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
                        # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
                        ex_cache.current_seq_len -= 1

            if not cache_reused:
                ex_cache.current_seq_len = 0
                if len(seq_tensor) > 1:
                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)

            logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()

        # Remember what is in the cache now, for the next call's prefix match.
        if is_negative:
            self.past_seq_negative = seq_tensor
        else:
            self.past_seq = seq_tensor

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)

            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        """Build the wrapper for the model stored under ``shared.args.model_dir``.

        No extra positional or keyword arguments are accepted. The ExLlamaV2
        config is prepared from the model directory and then overridden with
        the matching ``shared.args`` values.
        """
        assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"

        model_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)

        config = ExLlamaV2Config()
        config.model_dir = str(model_path)
        config.prepare()

        config.max_seq_len = shared.args.ctx_size
        config.scale_pos_emb = shared.args.compress_pos_emb
        config.scale_alpha_value = shared.args.alpha_value
        config.no_flash_attn = shared.args.no_flash_attn
        config.no_xformers = shared.args.no_xformers
        config.no_sdpa = shared.args.no_sdpa
        config.num_experts_per_token = int(shared.args.num_experts_per_token)

        return Exllamav2HF(config)

View file

@ -1,6 +1,3 @@
import math
import queue
import threading
import traceback import traceback
from pathlib import Path from pathlib import Path
from typing import Any, List, Tuple from typing import Any, List, Tuple
@ -10,10 +7,8 @@ import torch
from exllamav3 import Cache, Config, Generator, Model, Tokenizer from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job from exllamav3.generator import Job
from exllamav3.generator.filter import Filter
from exllamav3.generator.sampler import ( from exllamav3.generator.sampler import (
CustomSampler, CustomSampler,
SS_AdaptiveP,
SS_Argmax, SS_Argmax,
SS_MinP, SS_MinP,
SS_PresFreqP, SS_PresFreqP,
@ -38,95 +33,10 @@ except Exception:
traceback.print_exc() traceback.print_exc()
class LogitBiasFilter(Filter):
    """Filter that adds a fixed per-token bias to the logits.

    ``logit_bias_dict`` maps token ids (anything ``int()`` accepts) to bias
    values. The mask is built on first use and the same tensor is returned
    on every later call.
    """

    def __init__(self, tokenizer, logit_bias_dict):
        super().__init__(
            tokenizer=tokenizer,
            trigger_token=None,
            prefix_str=None,
            eos_after_completed=False,
        )
        self.logit_bias_dict = logit_bias_dict
        self._mask = None

    # The bias is static, so the filter keeps no per-token state.
    def reset(self):
        pass

    def accept_token(self, token):
        pass

    def is_completed(self):
        return False

    def use_background_worker(self):
        return False

    def get_next_logit_mask(self):
        """Return the cached (1, vocab_size) bias mask, building it on first call."""
        if self._mask is not None:
            return self._mask

        self._mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
        for raw_id, bias_value in self.logit_bias_dict.items():
            idx = int(raw_id)
            # Ids outside the vocabulary are ignored.
            if idx < 0 or idx >= self.vocab_size:
                continue
            self._mask[0, idx] = bias_value

        return self._mask
class ConcurrentGenerator:
    """Drives ``generator.iterate()`` on a background thread and routes each
    result to the queue of the job it belongs to.

    A ``None`` placed on a job's queue means the job was cancelled or the
    iteration failed; no further results will follow for it.
    """

    def __init__(self, generator):
        self.generator = generator
        self.lock = threading.Lock()
        self.job_queues = {}
        self.active = True
        self.has_jobs = threading.Event()
        self.thread = threading.Thread(target=self._iterate_loop, daemon=True)
        self.thread.start()

    def _iterate_loop(self):
        """Worker loop: wait for work (or 0.5 s), then run one step under the lock."""
        while self.active:
            self.has_jobs.wait(timeout=0.5)
            with self.lock:
                self._run_one_step()

    def _run_one_step(self):
        """Perform a single ``iterate()`` and dispatch its results. Caller holds ``self.lock``."""
        if not self.job_queues:
            self.has_jobs.clear()
            return

        try:
            results = self.generator.iterate()
        except Exception:
            logger.error("Exception in ConcurrentGenerator iterate loop:\n" + traceback.format_exc())
            # Tell every waiting consumer to give up, then drop all work.
            for waiting in self.job_queues.values():
                waiting.put(None)
            self.job_queues.clear()
            self.generator.clear_queue()
            self.has_jobs.clear()
            return

        for item in results:
            owner = item["job"]
            target = self.job_queues.get(owner)
            if not target:
                continue
            target.put(item)
            # A finished job no longer needs routing.
            if item.get("eos"):
                self.job_queues.pop(owner, None)

        if not self.job_queues:
            self.has_jobs.clear()

    def submit(self, job) -> queue.Queue:
        """Enqueue ``job`` on the generator and return the queue its results will arrive on."""
        results_queue = queue.Queue()
        with self.lock:
            self.job_queues[job] = results_queue
            self.generator.enqueue(job)
            self.has_jobs.set()
        return results_queue

    def cancel(self, job):
        """Cancel ``job`` if it is still tracked and post ``None`` to its queue."""
        with self.lock:
            if job not in self.job_queues:
                return
            self.generator.cancel(job)
            self.job_queues[job].put(None)
            del self.job_queues[job]

    def stop(self):
        """Ask the worker thread to exit and wait up to 5 seconds for it."""
        self.active = False
        # Wake the worker so it notices the flag without waiting for the timeout.
        self.has_jobs.set()
        self.thread.join(timeout=5)
class Exllamav3Model: class Exllamav3Model:
def __init__(self): def __init__(self):
pass pass
@property
def device(self) -> torch.device:
return torch.device(0)
@classmethod @classmethod
def from_pretrained(cls, path_to_model): def from_pretrained(cls, path_to_model):
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model) path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
@ -148,7 +58,7 @@ class Exllamav3Model:
logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}") logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
max_tokens = adjusted_tokens max_tokens = adjusted_tokens
# Parse cache type # Parse cache type (ExLlamaV2 pattern)
cache_type = shared.args.cache_type.lower() cache_type = shared.args.cache_type.lower()
cache_kwargs = {} cache_kwargs = {}
if cache_type == 'fp16': if cache_type == 'fp16':
@ -187,21 +97,8 @@ class Exllamav3Model:
load_params['tensor_p'] = True load_params['tensor_p'] = True
load_params['tp_backend'] = shared.args.tp_backend load_params['tp_backend'] = shared.args.tp_backend
# Load vision and draft before the main model so autosplit model.load(**load_params)
# accounts for their VRAM usage. tokenizer = Tokenizer.from_config(config)
# Load vision model component (ExLlamaV3 native)
vision_model = None
if "vision_config" in config.config_dict:
logger.info("Vision component detected in model config. Attempting to load...")
try:
vision_model = Model.from_config(config, component="vision")
vision_model.load(progressbar=True)
logger.info("Vision model loaded successfully.")
except Exception as e:
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
else:
logger.info("No vision component in model config. Skipping multimodal setup.")
# Initialize draft model for speculative decoding # Initialize draft model for speculative decoding
draft_model = None draft_model = None
@ -217,8 +114,23 @@ class Exllamav3Model:
logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.") logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
else: else:
draft_config = Config.from_directory(str(draft_path)) draft_config = Config.from_directory(str(draft_path))
# Set context size for draft model with 256-multiple validation
if shared.args.ctx_size_draft > 0:
draft_max_tokens = shared.args.ctx_size_draft
else:
draft_max_tokens = shared.args.ctx_size
# Validate draft model context size is a multiple of 256
if draft_max_tokens % 256 != 0:
adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
draft_max_tokens = adjusted_draft_tokens
draft_config.max_seq_len = draft_max_tokens
draft_model = Model.from_config(draft_config) draft_model = Model.from_config(draft_config)
draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs) draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
draft_load_params = {'progressbar': True} draft_load_params = {'progressbar': True}
if split: if split:
@ -227,9 +139,18 @@ class Exllamav3Model:
draft_model.load(**draft_load_params) draft_model.load(**draft_load_params)
logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}") logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
# Load main model last # Load vision model component (ExLlamaV3 native)
model.load(**load_params) vision_model = None
tokenizer = Tokenizer.from_config(config) if "vision_config" in config.config_dict:
logger.info("Vision component detected in model config. Attempting to load...")
try:
vision_model = Model.from_config(config, component="vision")
vision_model.load(progressbar=True)
logger.info("Vision model loaded successfully.")
except Exception as e:
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
else:
logger.info("No vision component in model config. Skipping multimodal setup.")
generator = Generator( generator = Generator(
model=model, model=model,
@ -237,7 +158,7 @@ class Exllamav3Model:
tokenizer=tokenizer, tokenizer=tokenizer,
draft_model=draft_model, draft_model=draft_model,
draft_cache=draft_cache, draft_cache=draft_cache,
num_draft_tokens=shared.args.draft_max if draft_model is not None else 0, num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0,
) )
result = cls() result = cls()
@ -245,7 +166,6 @@ class Exllamav3Model:
result.cache = cache result.cache = cache
result.tokenizer = tokenizer result.tokenizer = tokenizer
result.generator = generator result.generator = generator
result.parallel_generator = ConcurrentGenerator(generator)
result.config = config result.config = config
result.max_tokens = max_tokens result.max_tokens = max_tokens
result.vision_model = vision_model result.vision_model = vision_model
@ -366,16 +286,11 @@ class Exllamav3Model:
# 3. Get the priority list and handle temperature_last # 3. Get the priority list and handle temperature_last
default_priority = ['repetition_penalty', 'presence_frequency_penalty', 'top_k', 'top_p', 'min_p', 'temperature'] default_priority = ['repetition_penalty', 'presence_frequency_penalty', 'top_k', 'top_p', 'min_p', 'temperature']
sampler_priority = list(state.get('sampler_priority') or default_priority) sampler_priority = state.get('sampler_priority') or default_priority
if state['temperature_last'] and 'temperature' in sampler_priority: if state['temperature_last'] and 'temperature' in sampler_priority:
sampler_priority.append(sampler_priority.pop(sampler_priority.index('temperature'))) sampler_priority.append(sampler_priority.pop(sampler_priority.index('temperature')))
# The preset system uses separate 'presence_penalty' and
# 'frequency_penalty', but ExLlamaV3 has a single combined
# SS_PresFreqP sampler. Normalize to the combined name.
sampler_priority = ['presence_frequency_penalty' if x in ('presence_penalty', 'frequency_penalty') else x for x in sampler_priority]
# 4. Sort the unordered list based on the priority list # 4. Sort the unordered list based on the priority list
def custom_sort_key(sampler_obj): def custom_sort_key(sampler_obj):
class_name = sampler_obj.__class__.__name__ class_name = sampler_obj.__class__.__name__
@ -387,11 +302,7 @@ class Exllamav3Model:
ordered_samplers = sorted(unordered_samplers, key=custom_sort_key) ordered_samplers = sorted(unordered_samplers, key=custom_sort_key)
# 5. Add the final sampling stage and build the sampler # 5. Add the final sampling stage and build the sampler
if state.get('adaptive_target', 0) > 0:
ordered_samplers.append(SS_AdaptiveP(state['adaptive_target'], state['adaptive_decay']))
else:
ordered_samplers.append(SS_Sample()) ordered_samplers.append(SS_Sample())
sampler = CustomSampler(ordered_samplers) sampler = CustomSampler(ordered_samplers)
# Encode prompt with embeddings (ExLlamaV3-specific) # Encode prompt with embeddings (ExLlamaV3-specific)
@ -412,86 +323,43 @@ class Exllamav3Model:
else: else:
max_new_tokens = state['max_new_tokens'] max_new_tokens = state['max_new_tokens']
# Use full EOS token list from config (may contain multiple IDs) # Get stop conditions
stop_conditions = [] stop_conditions = []
if not state['ban_eos_token']: if not state['ban_eos_token']:
for eos_id in self.config.eos_token_id_list: if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
if eos_id is not None: stop_conditions.append(self.tokenizer.eos_token_id)
stop_conditions.append(eos_id)
# Build filters for logit_bias (OpenAI API)
filters = []
logit_bias = state.get('logit_bias')
if logit_bias:
filters.append(LogitBiasFilter(self.tokenizer, logit_bias))
# Logprobs support (OpenAI API)
logprobs = state.get('logprobs', 0) or 0
return_top_tokens = logprobs if logprobs > 0 else 0
seed = state.get('seed', -1)
job = Job( job = Job(
input_ids=input_ids, input_ids=input_ids,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
decode_special_tokens=not state['skip_special_tokens'], decode_special_tokens=not state['skip_special_tokens'],
embeddings=image_embeddings if image_embeddings else None, embeddings=image_embeddings if image_embeddings else None,
sampler=sampler, sampler=sampler,
seed=seed if seed >= 0 else None,
stop_conditions=stop_conditions if stop_conditions else None, stop_conditions=stop_conditions if stop_conditions else None,
filters=filters if filters else None,
return_top_tokens=return_top_tokens,
return_probs=return_top_tokens > 0,
) )
# Stream generation # Stream generation
self.generator.enqueue(job)
response_text = "" response_text = ""
stop_event = state.get('stop_event')
self.last_completion_probabilities = []
result_queue = self.parallel_generator.submit(job)
try: try:
while True: while self.generator.num_remaining_jobs():
if shared.stop_everything or (stop_event and stop_event.is_set()): if shared.stop_everything:
break break
try:
result = result_queue.get(timeout=0.1) results = self.generator.iterate()
except queue.Empty: for result in results:
continue if "eos" in result and result["eos"]:
if result is None or result.get("eos"):
# Capture logprobs from the final eos result too
if result is not None and return_top_tokens > 0:
self._capture_logprobs(result)
break break
chunk = result.get("text", "") chunk = result.get("text", "")
# Capture logprobs from streaming results
if return_top_tokens > 0:
self._capture_logprobs(result)
if chunk: if chunk:
response_text += chunk response_text += chunk
yield response_text yield response_text
finally: finally:
self.parallel_generator.cancel(job) self.generator.clear_queue()
def _capture_logprobs(self, result):
"""Convert ExLlamav3 top-k token data to the shared logprobs format."""
top_k_tokens = result.get("top_k_tokens")
top_k_probs = result.get("top_k_probs")
if top_k_tokens is None or top_k_probs is None:
return
id_to_piece = self.tokenizer.get_id_to_piece_list(True)
# top_k_tokens shape: (batch, seq_len, k), top_k_probs same
for seq_idx in range(top_k_tokens.shape[1]):
entry = {"top_logprobs": []}
for k_idx in range(top_k_tokens.shape[2]):
token_id = top_k_tokens[0, seq_idx, k_idx].item()
prob = top_k_probs[0, seq_idx, k_idx].item()
token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
logprob = math.log(prob) if prob > 0 else float("-inf")
entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
self.last_completion_probabilities.append(entry)
def generate(self, prompt, state): def generate(self, prompt, state):
output = "" output = ""
@ -554,13 +422,6 @@ class Exllamav3Model:
def unload(self): def unload(self):
logger.info("Unloading ExLlamaV3 model components...") logger.info("Unloading ExLlamaV3 model components...")
if hasattr(self, 'parallel_generator') and self.parallel_generator is not None:
try:
self.parallel_generator.stop()
except Exception as e:
logger.warning(f"Error stopping parallel generator: {e}")
self.parallel_generator = None
if hasattr(self, 'vision_model') and self.vision_model is not None: if hasattr(self, 'vision_model') and self.vision_model is not None:
try: try:
del self.vision_model del self.vision_model

View file

@ -84,12 +84,6 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
self.ex_model.load(**load_params) self.ex_model.load(**load_params)
self.past_seq = None self.past_seq = None
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.layer_type = layer_type
self.cache_kwargs = cache_kwargs
if shared.args.cfg_cache:
self.ex_cache_negative = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
self.past_seq_negative = None
def _validate_model_class(self): def _validate_model_class(self):
pass pass
@ -132,7 +126,7 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
reset = True reset = True
# Maximum number of tokens to process in a single forward pass # Maximum number of tokens to process in a single forward pass
max_chunk_size = 2048 max_chunk_size = 256
# Make the forward call # Make the forward call
if labels is None: if labels is None:
@ -153,16 +147,17 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
# Process tokens from longest_prefix to second-to-last token # Process tokens from longest_prefix to second-to-last token
tokens_to_process = seq_tensor[longest_prefix:-1] tokens_to_process = seq_tensor[longest_prefix:-1]
# Use prefill() to fill the cache without computing logits # Process in chunks if the number of tokens is large
for i in range(0, tokens_to_process.shape[0], max_chunk_size): for i in range(0, tokens_to_process.shape[0], max_chunk_size):
chunk = tokens_to_process[i:i + max_chunk_size] chunk = tokens_to_process[i:i + max_chunk_size]
self.ex_model.prefill( self.ex_model.forward(
input_ids=chunk.view(1, -1), input_ids=chunk.view(1, -1),
params={ params={
"attn_mode": "flash_attn", "attn_mode": "flash_attn",
"cache": ex_cache, "cache": ex_cache,
"past_len": longest_prefix + i, "past_len": longest_prefix + i,
"batch_shape": (1, self.max_tokens), "batch_shape": (1, self.max_tokens),
"reconstruct": False # Force memory-efficient path
} }
) )
@ -173,17 +168,18 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
# Process all tokens except the last one # Process all tokens except the last one
tokens_to_process = seq_tensor[:-1] tokens_to_process = seq_tensor[:-1]
# Use prefill() to fill the cache without computing logits # Process in chunks if the number of tokens is large
current_len = 0 current_len = 0
for i in range(0, tokens_to_process.shape[0], max_chunk_size): for i in range(0, tokens_to_process.shape[0], max_chunk_size):
chunk = tokens_to_process[i:i + max_chunk_size] chunk = tokens_to_process[i:i + max_chunk_size]
self.ex_model.prefill( self.ex_model.forward(
input_ids=chunk.view(1, -1), input_ids=chunk.view(1, -1),
params={ params={
"attn_mode": "flash_attn", "attn_mode": "flash_attn",
"cache": ex_cache, "cache": ex_cache,
"past_len": current_len, "past_len": current_len,
"batch_shape": (1, self.max_tokens), "batch_shape": (1, self.max_tokens),
"reconstruct": False # Force memory-efficient path
} }
) )
current_len += chunk.shape[0] current_len += chunk.shape[0]
@ -198,26 +194,24 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
"cache": ex_cache, "cache": ex_cache,
"past_len": current_len, "past_len": current_len,
"batch_shape": (1, self.max_tokens), "batch_shape": (1, self.max_tokens),
"reconstruct": False # Force memory-efficient path
} }
).to(input_ids.device).float() ).to(input_ids.device).float()
else: else:
# Labels path: use cache for cross-chunk attention. # When processing with labels, handle as a complete sequence
# Process in chunks if the number of tokens is large
tokens_to_process = seq_tensor tokens_to_process = seq_tensor
all_logits = None all_logits = None
current_len = 0
for i in range(0, tokens_to_process.shape[0], max_chunk_size): for i in range(0, tokens_to_process.shape[0], max_chunk_size):
chunk = tokens_to_process[i:i + max_chunk_size] chunk = tokens_to_process[i:i + max_chunk_size]
chunk_logits = self.ex_model.forward( chunk_logits = self.ex_model.forward(
input_ids=chunk.view(1, -1), input_ids=chunk.view(1, -1),
params={ params={
"attn_mode": "flash_attn", "attn_mode": "flash_attn_nc", # No caching for training
"cache": ex_cache, "reconstruct": False # Force memory-efficient path
"past_len": current_len,
"batch_shape": (1, self.max_tokens),
} }
).float() ).float()
current_len += chunk.shape[0]
if all_logits is None: if all_logits is None:
all_logits = chunk_logits all_logits = chunk_logits

View file

@ -1,11 +1,11 @@
import importlib import importlib
import importlib.util
import sys
import traceback import traceback
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from pathlib import Path from pathlib import Path
import gradio as gr
import modules.shared as shared import modules.shared as shared
from modules.logging_colors import logger from modules.logging_colors import logger
@ -38,15 +38,9 @@ def load_extensions():
try: try:
# Prefer user extension, fall back to system extension # Prefer user extension, fall back to system extension
user_script_path = shared.user_data_dir / 'extensions' / name / 'script.py' user_script_path = Path(f'user_data/extensions/{name}/script.py')
if user_script_path.exists(): if user_script_path.exists():
spec = importlib.util.spec_from_file_location( extension = importlib.import_module(f"user_data.extensions.{name}.script")
f"user_ext_{name}",
str(user_script_path)
)
extension = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = extension
spec.loader.exec_module(extension)
else: else:
extension = importlib.import_module(f"extensions.{name}.script") extension = importlib.import_module(f"extensions.{name}.script")
@ -59,7 +53,7 @@ def load_extensions():
state[name] = [True, i, extension] # Store extension object state[name] = [True, i, extension] # Store extension object
except ModuleNotFoundError: except ModuleNotFoundError:
extension_location = shared.user_data_dir / 'extensions' / name if user_script_path.exists() else Path('extensions') / name extension_location = Path('user_data/extensions') / name if user_script_path.exists() else Path('extensions') / name
windows_path = str(extension_location).replace('/', '\\') windows_path = str(extension_location).replace('/', '\\')
logger.error( logger.error(
f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n" f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n"
@ -212,7 +206,6 @@ def _apply_custom_js():
def create_extensions_block(): def create_extensions_block():
import gradio as gr
to_display = [] to_display = []
for extension, name in iterator(): for extension, name in iterator():
if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)): if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
@ -227,7 +220,6 @@ def create_extensions_block():
def create_extensions_tabs(): def create_extensions_tabs():
import gradio as gr
for extension, name in iterator(): for extension, name in iterator():
if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)): if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
display_name = getattr(extension, 'params', {}).get('display_name', name) display_name = getattr(extension, 'params', {}).get('display_name', name)

97
modules/gradio_hijack.py Normal file
View file

@ -0,0 +1,97 @@
'''
Most of the code here was adapted from:
https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184
'''
import inspect
import warnings
from functools import wraps
import gradio as gr
import gradio.routes
import gradio.utils
from starlette.middleware.trustedhost import TrustedHostMiddleware
from modules import shared
orig_create_app = gradio.routes.App.create_app
# Be strict about only approving access to localhost by default
def create_app_with_trustedhost(*args, **kwargs):
    """Create the gradio app and restrict accepted hosts unless it is exposed.

    All arguments are forwarded unchanged to the saved ``orig_create_app``.
    When neither ``shared.args.listen`` nor ``shared.args.share`` is set, a
    ``TrustedHostMiddleware`` allowing only ``localhost`` and ``127.0.0.1`` is
    added before the app is returned.
    """
    app = orig_create_app(*args, **kwargs)

    exposed = shared.args.listen or shared.args.share
    if exposed:
        return app

    app.add_middleware(TrustedHostMiddleware, allowed_hosts=["localhost", "127.0.0.1"])
    return app
# Route every app construction through the host-restricting wrapper
gradio.routes.App.create_app = create_app_with_trustedhost
# Replace gradio's launch_counter with a no-op
gradio.utils.launch_counter = lambda: None
class GradioDeprecationWarning(DeprecationWarning):
    """Warning category used by repair() for keyword arguments a component does not accept."""
def repair(grclass):
    """Patch a gradio class so that it tolerates legacy keyword arguments.

    Classes without a truthy ``EVENTS`` attribute (and ``None``) are left
    untouched. For every other class the constructor is wrapped so that:

    * ``tooltip`` is accepted and stored on the instance as ``webui_tooltip``;
    * ``source=x`` is translated to ``sources=[x]``;
    * keyword arguments missing from the original constructor's signature are
      dropped with a ``GradioDeprecationWarning`` instead of being forwarded;
    * each event listener named in ``EVENTS`` accepts ``_js`` and forwards it
      as ``js``.

    ``grclass.update`` is also set to ``gr.update``.
    """
    if not getattr(grclass, 'EVENTS', None):
        return

    def accept_legacy_js(listener):
        # Build the per-listener shim; `listener` is bound per call, so each
        # event keeps its own target
        def fun(*xargs, _js=None, **xkwargs):
            if _js:
                xkwargs['js'] = _js

            return listener(*xargs, **xkwargs)

        return fun

    @wraps(grclass.__init__)
    def __repaired_init__(self, *args, tooltip=None, source=None, original=grclass.__init__, **kwargs):
        if source:
            kwargs["sources"] = [source]

        accepted = inspect.signature(original).parameters
        forwarded = {}
        for key, value in kwargs.items():
            if key not in accepted:
                # Warn from this frame so stacklevel=2 keeps pointing at the caller
                warnings.warn(f"unexpected argument for {grclass.__name__}: {key}", GradioDeprecationWarning, stacklevel=2)
                continue

            forwarded[key] = value

        original(self, *args, **forwarded)

        self.webui_tooltip = tooltip

        for event in self.EVENTS:
            event_name = str(event)
            setattr(self, event_name, accept_legacy_js(getattr(self, event_name)))

    grclass.__init__ = __repaired_init__
    grclass.update = gr.update
# Patch every component and layout class that gradio exports by name; names
# missing from the gr namespace resolve to None, which repair() ignores
for component in {*gr.components.__all__, *gr.layouts.__all__}:
    repair(getattr(gr, component, None))
class Dependency(gr.events.Dependency):
    """gr.events.Dependency whose ``then`` also accepts the ``_js`` keyword.

    After normal construction, the instance's ``then`` is replaced by a shim
    that forwards a truthy ``_js`` argument to the original ``then`` as ``js``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Capture the original bound method before it is shadowed below
        original_then = self.then

        def then(*xargs, _js=None, **xkwargs):
            forwarded = {**xkwargs, 'js': _js} if _js else xkwargs
            return original_then(*xargs, **forwarded)

        self.then = then
# Install the patched Dependency class in place of gradio's own
gr.events.Dependency = Dependency
# Expose gr.Box as an alias of gr.Group
# NOTE(review): presumably for code written against gradio versions that still had Box; confirm
gr.Box = gr.Group

View file

@ -10,7 +10,6 @@ import markdown
from PIL import Image, ImageOps from PIL import Image, ImageOps
from modules import shared from modules import shared
from modules.reasoning import extract_reasoning
from modules.sane_markdown_lists import SaneListExtension from modules.sane_markdown_lists import SaneListExtension
from modules.utils import get_available_chat_styles from modules.utils import get_available_chat_styles
@ -110,40 +109,94 @@ def replace_blockquote(m):
def extract_thinking_block(string): def extract_thinking_block(string):
"""Extract thinking blocks from the beginning of an HTML-escaped string.""" """Extract thinking blocks from the beginning of a string."""
return extract_reasoning(string, html_escaped=True) if not string:
return None, string
THINK_START_TAG = "&lt;think&gt;"
THINK_END_TAG = "&lt;/think&gt;"
# Look for think tag first
start_pos = string.find(THINK_START_TAG)
end_pos = string.find(THINK_END_TAG)
# If think tags found, use existing logic
if start_pos != -1 or end_pos != -1:
# handle missing start or end tags
if start_pos == -1:
thought_start = 0
else:
thought_start = start_pos + len(THINK_START_TAG)
if end_pos == -1:
thought_end = len(string)
content_start = len(string)
else:
thought_end = end_pos
content_start = end_pos + len(THINK_END_TAG)
thinking_content = string[thought_start:thought_end]
remaining_content = string[content_start:]
return thinking_content, remaining_content
# If think tags not found, try GPT-OSS alternative format
ALT_START = "&lt;|channel|&gt;analysis&lt;|message|&gt;"
ALT_END = "&lt;|end|&gt;"
ALT_CONTENT_START = "&lt;|start|&gt;assistant&lt;|channel|&gt;final&lt;|message|&gt;"
alt_start_pos = string.find(ALT_START)
alt_end_pos = string.find(ALT_END)
alt_content_pos = string.find(ALT_CONTENT_START)
if alt_start_pos != -1 or alt_end_pos != -1:
if alt_start_pos == -1:
thought_start = 0
else:
thought_start = alt_start_pos + len(ALT_START)
# If no explicit end tag but content start exists, use content start as end
if alt_end_pos == -1:
if alt_content_pos != -1:
thought_end = alt_content_pos
content_start = alt_content_pos + len(ALT_CONTENT_START)
else:
thought_end = len(string)
content_start = len(string)
else:
thought_end = alt_end_pos
content_start = alt_content_pos + len(ALT_CONTENT_START) if alt_content_pos != -1 else alt_end_pos + len(ALT_END)
thinking_content = string[thought_start:thought_end]
remaining_content = string[content_start:]
return thinking_content, remaining_content
# Try seed:think format
SEED_START = "&lt;seed:think&gt;"
SEED_END = "&lt;/seed:think&gt;"
seed_start_pos = string.find(SEED_START)
seed_end_pos = string.find(SEED_END)
if seed_start_pos != -1 or seed_end_pos != -1:
if seed_start_pos == -1:
thought_start = 0
else:
thought_start = seed_start_pos + len(SEED_START)
if seed_end_pos == -1:
thought_end = len(string)
content_start = len(string)
else:
thought_end = seed_end_pos
content_start = seed_end_pos + len(SEED_END)
thinking_content = string[thought_start:thought_end]
remaining_content = string[content_start:]
return thinking_content, remaining_content
# Return if no format is found
return None, string
def build_thinking_block(thinking_content, message_id, has_remaining_content):
def build_tool_call_block(header, body, message_id, index):
    """Build the HTML accordion for one tool call.

    Args:
        header: Title text; HTML-escaped before insertion.
        body: Tool call payload. The literal string ``'...'`` marks a pending
            call and yields a title-only block with a trailing ellipsis.
        message_id: Identifier of the message the block belongs to.
        index: Position of the tool call within the message.

    Returns:
        An HTML string whose ``data-block-id`` is
        ``tool-call-{message_id}-{index}``.
    """
    block_id = f"tool-call-{message_id}-{index}"
    title = html.escape(header)

    if body != '...':
        # Emit a plain <pre> directly to avoid highlight.js auto-detection
        code = html.escape(body)
        return f'''
<details class="thinking-block" data-block-id="{block_id}">
<summary class="thinking-header">
{tool_svg_small}
<span class="thinking-title">{title}</span>
</summary>
<div class="thinking-content pretty_scrollbar"><pre><code class="nohighlight">{code}</code></pre></div>
</details>
'''

    # Pending placeholder: no expandable body, just the title with an ellipsis
    return f'''
<details class="thinking-block" data-block-id="{block_id}">
<summary class="thinking-header">
{tool_svg_small}
<span class="thinking-title">{title} ...</span>
</summary>
</details>
'''
def build_thinking_block(thinking_content, message_id, has_remaining_content, thinking_index=0):
"""Build HTML for a thinking block.""" """Build HTML for a thinking block."""
if thinking_content is None: if thinking_content is None:
return None return None
@ -152,7 +205,7 @@ def build_thinking_block(thinking_content, message_id, has_remaining_content, th
thinking_html = process_markdown_content(thinking_content) thinking_html = process_markdown_content(thinking_content)
# Generate unique ID for the thinking block # Generate unique ID for the thinking block
block_id = f"thinking-{message_id}-{thinking_index}" block_id = f"thinking-{message_id}-0"
# Check if thinking is complete or still in progress # Check if thinking is complete or still in progress
is_streaming = not has_remaining_content is_streaming = not has_remaining_content
@ -185,27 +238,23 @@ def process_markdown_content(string):
if not string: if not string:
return "" return ""
# Define unique placeholders for LaTeX asterisks and underscores # Define a unique placeholder for LaTeX asterisks
LATEX_ASTERISK_PLACEHOLDER = "LATEXASTERISKPLACEHOLDER" LATEX_ASTERISK_PLACEHOLDER = "LATEXASTERISKPLACEHOLDER"
LATEX_UNDERSCORE_PLACEHOLDER = "LATEXUNDERSCOREPLACEHOLDER"
def protect_asterisks_underscores_in_latex(match): def protect_asterisks_in_latex(match):
"""A replacer function for re.sub to protect asterisks and underscores in multiple LaTeX formats.""" """A replacer function for re.sub to protect asterisks in multiple LaTeX formats."""
# Check which delimiter group was captured # Check which delimiter group was captured
if match.group(1) is not None: # Content from $$...$$ if match.group(1) is not None: # Content from $$...$$
content = match.group(1) content = match.group(1)
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER) modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER) return f'$${modified_content}$$'
return f'{modified_content}'
elif match.group(2) is not None: # Content from \[...\] elif match.group(2) is not None: # Content from \[...\]
content = match.group(2) content = match.group(2)
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER) modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER)
return f'\\[{modified_content}\\]' return f'\\[{modified_content}\\]'
elif match.group(3) is not None: # Content from \(...\) elif match.group(3) is not None: # Content from \(...\)
content = match.group(3) content = match.group(3)
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER) modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER)
return f'\\({modified_content}\\)' return f'\\({modified_content}\\)'
return match.group(0) # Fallback return match.group(0) # Fallback
@ -239,10 +288,9 @@ def process_markdown_content(string):
string = string.replace('\\end{equation*}', '$$') string = string.replace('\\end{equation*}', '$$')
string = re.sub(r"(.)```", r"\1\n```", string) string = re.sub(r"(.)```", r"\1\n```", string)
# Protect asterisks and underscores within all LaTeX blocks before markdown conversion # Protect asterisks within all LaTeX blocks before markdown conversion
latex_pattern = re.compile(r'((?:^|[\r\n\s])\$\$[^`]*?\$\$)|\\\[(.*?)\\\]|\\\((.*?)\\\)', latex_pattern = re.compile(r'\$\$(.*?)\$\$|\\\[(.*?)\\\]|\\\((.*?)\\\)', re.DOTALL)
re.DOTALL) string = latex_pattern.sub(protect_asterisks_in_latex, string)
string = latex_pattern.sub(protect_asterisks_underscores_in_latex, string)
result = '' result = ''
is_code = False is_code = False
@ -254,11 +302,11 @@ def process_markdown_content(string):
if stripped_line.startswith('```'): if stripped_line.startswith('```'):
is_code = not is_code is_code = not is_code
elif stripped_line.startswith('$$') and (stripped_line == "$$" or not stripped_line.endswith('$$')): elif stripped_line.startswith('$$'):
is_latex = not is_latex is_latex = not is_latex
elif stripped_line.endswith('$$'): elif stripped_line.endswith('$$'):
is_latex = False is_latex = False
elif stripped_line.startswith('\\\\[') and not stripped_line.endswith('\\\\]'): elif stripped_line.startswith('\\\\['):
is_latex = True is_latex = True
elif stripped_line.startswith('\\\\]'): elif stripped_line.startswith('\\\\]'):
is_latex = False is_latex = False
@ -303,9 +351,8 @@ def process_markdown_content(string):
# Convert to HTML using markdown # Convert to HTML using markdown
html_output = markdown.markdown(result, extensions=['fenced_code', 'tables', SaneListExtension()]) html_output = markdown.markdown(result, extensions=['fenced_code', 'tables', SaneListExtension()])
# Restore the LaTeX asterisks and underscores after markdown conversion # Restore the LaTeX asterisks after markdown conversion
html_output = html_output.replace(LATEX_ASTERISK_PLACEHOLDER, '*') html_output = html_output.replace(LATEX_ASTERISK_PLACEHOLDER, '*')
html_output = html_output.replace(LATEX_UNDERSCORE_PLACEHOLDER, '_')
# Remove extra newlines before </code> # Remove extra newlines before </code>
html_output = re.sub(r'\s*</code>', '</code>', html_output) html_output = re.sub(r'\s*</code>', '</code>', html_output)
@ -317,9 +364,6 @@ def process_markdown_content(string):
# Unescape backslashes # Unescape backslashes
html_output = html_output.replace('\\\\', '\\') html_output = html_output.replace('\\\\', '\\')
# Wrap tables in a scrollable div
html_output = html_output.replace('<table>', '<div class="table-wrapper pretty_scrollbar"><table>').replace('</table>', '</table></div>')
return html_output return html_output
@ -336,67 +380,25 @@ def convert_to_markdown(string, message_id=None):
if message_id is None: if message_id is None:
message_id = "unknown" message_id = "unknown"
# Find tool call blocks by position, then process the text segments # Extract different components from the string
# between them using extract_thinking_block (which supports all
# THINKING_FORMATS, including end-only variants like Qwen's).
tool_call_pattern = re.compile(r'<tool_call>(.*?)\n(.*?)\n</tool_call>', re.DOTALL)
tool_calls = list(tool_call_pattern.finditer(string))
if not tool_calls:
# No tool calls — use original single-pass extraction
thinking_content, remaining_content = extract_thinking_block(string) thinking_content, remaining_content = extract_thinking_block(string)
# Build individual HTML blocks
blocks = [] blocks = []
# Add thinking block if present
thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content)) thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content))
if thinking_html: if thinking_html:
blocks.append(thinking_html) blocks.append(thinking_html)
# Add main content block
main_html = build_main_content_block(remaining_content) main_html = build_main_content_block(remaining_content)
if main_html: if main_html:
blocks.append(main_html) blocks.append(main_html)
# Assemble all blocks into final HTML
return ''.join(blocks) return ''.join(blocks)
# Split string into text segments around tool_call blocks and
# run extract_thinking_block on each segment for full format support.
html_parts = []
last_end = 0
tool_idx = 0
think_idx = 0
def process_text_segment(text, is_last_segment):
"""Process a text segment between tool_call blocks for thinking content."""
nonlocal think_idx
if not text.strip():
return
while text.strip():
thinking_content, remaining = extract_thinking_block(text)
if thinking_content is None:
break
has_remaining = bool(remaining.strip()) or not is_last_segment
html_parts.append(build_thinking_block(thinking_content, message_id, has_remaining, think_idx))
think_idx += 1
text = remaining
if text.strip():
html_parts.append(process_markdown_content(text))
for tc in tool_calls:
# Process text before this tool_call
process_text_segment(string[last_end:tc.start()], is_last_segment=False)
# Add tool call accordion
header = tc.group(1).strip()
body = tc.group(2).strip()
html_parts.append(build_tool_call_block(header, body, message_id, tool_idx))
tool_idx += 1
last_end = tc.end()
# Process text after the last tool_call
process_text_segment(string[last_end:], is_last_segment=True)
return ''.join(html_parts)
def convert_to_markdown_wrapped(string, message_id=None, use_cache=True): def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):
''' '''
@ -453,7 +455,6 @@ branch_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24
edit_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="tabler-icon tabler-icon-pencil"><path d="M4 20h4l10.5 -10.5a2.828 2.828 0 1 0 -4 -4l-10.5 10.5v4"></path><path d="M13.5 6.5l4 4"></path></svg>''' edit_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="tabler-icon tabler-icon-pencil"><path d="M4 20h4l10.5 -10.5a2.828 2.828 0 1 0 -4 -4l-10.5 10.5v4"></path><path d="M13.5 6.5l4 4"></path></svg>'''
info_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>''' info_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
info_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>''' info_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
tool_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-tool"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M7 10h3v-3l-3.5 -3.5a6 6 0 0 1 8 8l6 6a2 2 0 0 1 -3 3l-6 -6a6 6 0 0 1 -8 -8l3.5 3.5" /></svg>'''
attachment_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.48-8.48l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48"></path></svg>''' attachment_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.48-8.48l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48"></path></svg>'''
copy_button = f'<button class="footer-button footer-copy-button" title="Copy" onclick="copyToClipboard(this)">{copy_svg}</button>' copy_button = f'<button class="footer-button footer-copy-button" title="Copy" onclick="copyToClipboard(this)">{copy_svg}</button>'
@ -647,10 +648,10 @@ def generate_instruct_html(history, last_message_only=False):
def get_character_image_with_cache_buster(): def get_character_image_with_cache_buster():
"""Get character image URL with cache busting based on file modification time""" """Get character image URL with cache busting based on file modification time"""
cache_path = shared.user_data_dir / "cache" / "pfp_character_thumb.png" cache_path = Path("user_data/cache/pfp_character_thumb.png")
if cache_path.exists(): if cache_path.exists():
mtime = int(cache_path.stat().st_mtime) mtime = int(cache_path.stat().st_mtime)
return f'<img src="file/{shared.user_data_dir}/cache/pfp_character_thumb.png?{mtime}" class="pfp_character">' return f'<img src="file/user_data/cache/pfp_character_thumb.png?{mtime}" class="pfp_character">'
return '' return ''
@ -674,8 +675,8 @@ def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=
# Get appropriate image # Get appropriate image
if role == "user": if role == "user":
img = (f'<img src="file/{shared.user_data_dir}/cache/pfp_me.png?{time.time() if reset_cache else ""}">' img = (f'<img src="file/user_data/cache/pfp_me.png?{time.time() if reset_cache else ""}">'
if (shared.user_data_dir / "cache" / "pfp_me.png").exists() else '') if Path("user_data/cache/pfp_me.png").exists() else '')
else: else:
img = img_bot img = img_bot

View file

@ -77,18 +77,7 @@ def process_message_content(content: Any) -> Tuple[str, List[Image.Image]]:
# Support external URLs # Support external URLs
try: try:
import requests import requests
from urllib.parse import urljoin response = requests.get(image_url, timeout=10)
from modules.web_search import _validate_url
_validate_url(image_url)
url = image_url
for _ in range(5):
response = requests.get(url, timeout=10, allow_redirects=False)
if response.is_redirect and 'Location' in response.headers:
url = urljoin(url, response.headers['Location'])
_validate_url(url)
else:
break
response.raise_for_status() response.raise_for_status()
image_data = response.content image_data = response.content
image = Image.open(io.BytesIO(image_data)) image = Image.open(io.BytesIO(image_data))

View file

@ -36,7 +36,6 @@ class LlamaServer:
self.process = None self.process = None
self.session = requests.Session() self.session = requests.Session()
self.vocabulary_size = None self.vocabulary_size = None
self.n_ctx = None
self.bos_token = "<s>" self.bos_token = "<s>"
self.last_prompt_token_count = 0 self.last_prompt_token_count = 0
@ -76,8 +75,6 @@ class LlamaServer:
"top_p": state["top_p"], "top_p": state["top_p"],
"min_p": state["min_p"], "min_p": state["min_p"],
"top_n_sigma": state["top_n_sigma"] if state["top_n_sigma"] > 0 else -1, "top_n_sigma": state["top_n_sigma"] if state["top_n_sigma"] > 0 else -1,
"adaptive_target": state["adaptive_target"] if state["adaptive_target"] > 0 else -1,
"adaptive_decay": state["adaptive_decay"],
"typical_p": state["typical_p"], "typical_p": state["typical_p"],
"repeat_penalty": state["repetition_penalty"], "repeat_penalty": state["repetition_penalty"],
"repeat_last_n": state["repetition_penalty_range"], "repeat_last_n": state["repetition_penalty_range"],
@ -122,32 +119,15 @@ class LlamaServer:
penalty_found = True penalty_found = True
# Move temperature to the end if temperature_last is true and temperature exists in the list # Move temperature to the end if temperature_last is true and temperature exists in the list
if state["temperature_last"] and "temperature" in filtered_samplers: if state["temperature_last"] and "temperature" in samplers:
filtered_samplers.remove("temperature") samplers.remove("temperature")
filtered_samplers.append("temperature") samplers.append("temperature")
# adaptive-p replaces the default dist sampler; llama.cpp always
# places it at the end of the chain regardless of position, so we
# activate it based on the parameter value rather than sampler order.
if state.get("adaptive_target", 0) > 0:
filtered_samplers.append("adaptive_p")
payload["samplers"] = filtered_samplers payload["samplers"] = filtered_samplers
logit_bias = []
if state['custom_token_bans']: if state['custom_token_bans']:
logit_bias.extend([[int(token_id.strip()), False] for token_id in state['custom_token_bans'].split(',') if token_id.strip()]) to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
payload["logit_bias"] = to_ban
if state.get('logit_bias'):
for token_id_str, bias in state['logit_bias'].items():
logit_bias.append([int(token_id_str), bias])
if logit_bias:
payload["logit_bias"] = logit_bias
n_probs = state.get('logprobs', 0)
if n_probs and n_probs > 0:
payload["n_probs"] = n_probs
return payload return payload
@ -220,19 +200,16 @@ class LlamaServer:
# Make the generation request # Make the generation request
response = self.session.post(url, json=payload, stream=True) response = self.session.post(url, json=payload, stream=True)
try: try:
if response.status_code == 400 and response.json().get("error", {}).get("type") == "exceed_context_size_error": if response.status_code == 400 and response.json()["error"]["type"] == "exceed_context_size_error":
logger.error("The request exceeds the available context size, try increasing it") logger.error("The request exceeds the available context size, try increasing it")
return
else: else:
response.raise_for_status() # Raise an exception for HTTP errors response.raise_for_status() # Raise an exception for HTTP errors
full_text = "" full_text = ""
self.last_completion_probabilities = []
# Process the streaming response # Process the streaming response
stop_event = state.get('stop_event')
for line in response.iter_lines(): for line in response.iter_lines():
if shared.stop_everything or (stop_event and stop_event.is_set()): if shared.stop_everything:
break break
if not line: if not line:
@ -253,10 +230,6 @@ class LlamaServer:
full_text += data['content'] full_text += data['content']
yield full_text yield full_text
# Capture logprobs if present
if 'completion_probabilities' in data:
self.last_completion_probabilities.extend(data['completion_probabilities'])
# Check if generation is complete # Check if generation is complete
if data.get('stop', False): if data.get('stop', False):
break break
@ -305,8 +278,6 @@ class LlamaServer:
return result["completion_probabilities"][0]["top_probs"] return result["completion_probabilities"][0]["top_probs"]
else: else:
return result["completion_probabilities"][0]["top_logprobs"] return result["completion_probabilities"][0]["top_logprobs"]
time.sleep(0.05)
else: else:
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}") raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
@ -321,35 +292,16 @@ class LlamaServer:
self.vocabulary_size = model_info["meta"]["n_vocab"] self.vocabulary_size = model_info["meta"]["n_vocab"]
def _get_bos_token(self): def _get_bos_token(self):
"""Get and store the model's BOS token and context size.""" """Get and store the model's BOS token."""
url = f"http://127.0.0.1:{self.port}/props" url = f"http://127.0.0.1:{self.port}/props"
response = self.session.get(url).json() response = self.session.get(url).json()
if "bos_token" in response: if "bos_token" in response:
self.bos_token = response["bos_token"] self.bos_token = response["bos_token"]
# Get actual n_ctx from the server (important when --fit auto-selects it)
n_ctx = response.get("default_generation_settings", {}).get("n_ctx")
if n_ctx:
self.n_ctx = n_ctx
def _is_port_available(self, port):
"""Check if a port is available for use."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(('', port))
return True
except OSError:
return False
def _find_available_port(self): def _find_available_port(self):
"""Find an available port, preferring main port + 5.""" """Find an available port by letting the OS assign one."""
preferred_port = shared.args.api_port + 5
if self._is_port_available(preferred_port):
return preferred_port
# Fall back to OS-assigned random port
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0)) s.bind(('', 0)) # Bind to port 0 to get an available port
return s.getsockname()[1] return s.getsockname()[1]
def _start_server(self): def _start_server(self):
@ -362,6 +314,8 @@ class LlamaServer:
cmd = [ cmd = [
self.server_path, self.server_path,
"--model", self.model_path, "--model", self.model_path,
"--ctx-size", str(shared.args.ctx_size),
"--gpu-layers", str(shared.args.gpu_layers),
"--batch-size", str(shared.args.batch_size), "--batch-size", str(shared.args.batch_size),
"--ubatch-size", str(shared.args.ubatch_size), "--ubatch-size", str(shared.args.ubatch_size),
"--port", str(self.port), "--port", str(self.port),
@ -369,19 +323,6 @@ class LlamaServer:
"--flash-attn", "on", "--flash-attn", "on",
] ]
if shared.args.ctx_size > 0:
cmd += ["--ctx-size", str(shared.args.ctx_size)]
elif shared.args.gpu_layers >= 0:
cmd += ["--ctx-size", "8192"]
if shared.args.gpu_layers >= 0:
cmd += ["--gpu-layers", str(shared.args.gpu_layers), "--fit", "off"]
else:
cmd += ["--fit", "on"]
cmd += ["--fit-ctx", "8192"]
if shared.args.fit_target:
cmd += ["--fit-target", shared.args.fit_target]
if shared.args.threads > 0: if shared.args.threads > 0:
cmd += ["--threads", str(shared.args.threads)] cmd += ["--threads", str(shared.args.threads)]
if shared.args.threads_batch > 0: if shared.args.threads_batch > 0:
@ -404,10 +345,14 @@ class LlamaServer:
if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types: if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type] cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
cache_type = shared.args.cache_type cache_type = shared.args.cache_type
if shared.args.compress_pos_emb != 1:
cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
if shared.args.rope_freq_base > 0:
cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
if shared.args.mmproj not in [None, 'None']: if shared.args.mmproj not in [None, 'None']:
path = Path(shared.args.mmproj) path = Path(shared.args.mmproj)
if not path.exists(): if not path.exists():
path = shared.user_data_dir / 'mmproj' / shared.args.mmproj path = Path('user_data/mmproj') / shared.args.mmproj
if path.exists(): if path.exists():
cmd += ["--mmproj", str(path)] cmd += ["--mmproj", str(path)]
@ -419,7 +364,7 @@ class LlamaServer:
else: else:
model_file = sorted(path.glob('*.gguf'))[0] model_file = sorted(path.glob('*.gguf'))[0]
cmd += ["--model-draft", str(model_file)] cmd += ["--model-draft", model_file]
if shared.args.draft_max > 0: if shared.args.draft_max > 0:
cmd += ["--draft-max", str(shared.args.draft_max)] cmd += ["--draft-max", str(shared.args.draft_max)]
if shared.args.gpu_layers_draft > 0: if shared.args.gpu_layers_draft > 0:
@ -428,13 +373,6 @@ class LlamaServer:
cmd += ["--device-draft", shared.args.device_draft] cmd += ["--device-draft", shared.args.device_draft]
if shared.args.ctx_size_draft > 0: if shared.args.ctx_size_draft > 0:
cmd += ["--ctx-size-draft", str(shared.args.ctx_size_draft)] cmd += ["--ctx-size-draft", str(shared.args.ctx_size_draft)]
if shared.args.spec_type != 'none':
cmd += ["--spec-type", shared.args.spec_type]
cmd += ["--draft-max", str(shared.args.draft_max)]
cmd += ["--spec-ngram-size-n", str(shared.args.spec_ngram_size_n)]
cmd += ["--spec-ngram-size-m", str(shared.args.spec_ngram_size_m)]
cmd += ["--spec-ngram-min-hits", str(shared.args.spec_ngram_min_hits)]
cmd += ["--parallel", str(shared.args.parallel)]
if shared.args.streaming_llm: if shared.args.streaming_llm:
cmd += ["--cache-reuse", "1"] cmd += ["--cache-reuse", "1"]
cmd += ["--swa-full"] cmd += ["--swa-full"]
@ -447,11 +385,8 @@ class LlamaServer:
extra_flags = extra_flags[1:-1].strip() extra_flags = extra_flags[1:-1].strip()
for flag_item in extra_flags.split(','): for flag_item in extra_flags.split(','):
flag_item = flag_item.strip()
if '=' in flag_item: if '=' in flag_item:
flag, value = flag_item.split('=', 1) flag, value = flag_item.split('=', 1)
flag = flag.strip()
value = value.strip()
if len(flag) <= 3: if len(flag) <= 3:
cmd += [f"-{flag}", value] cmd += [f"-{flag}", value]
else: else:
@ -475,9 +410,7 @@ class LlamaServer:
print(' '.join(str(item) for item in cmd[1:])) print(' '.join(str(item) for item in cmd[1:]))
print() print()
gpu_layers_str = "auto" if shared.args.gpu_layers < 0 else str(shared.args.gpu_layers) logger.info(f"Using gpu_layers={shared.args.gpu_layers} | ctx_size={shared.args.ctx_size} | cache_type={cache_type}")
ctx_size_str = "auto" if shared.args.ctx_size == 0 and shared.args.gpu_layers < 0 else str(shared.args.ctx_size or 8192)
logger.info(f"Using gpu_layers={gpu_layers_str} | ctx_size={ctx_size_str} | cache_type={cache_type}")
# Start the server with pipes for output # Start the server with pipes for output
self.process = subprocess.Popen( self.process = subprocess.Popen(
cmd, cmd,
@ -501,7 +434,7 @@ class LlamaServer:
response = self.session.get(health_url) response = self.session.get(health_url)
if response.status_code == 200: if response.status_code == 200:
break break
except Exception: except:
pass pass
time.sleep(1) time.sleep(1)
@ -531,7 +464,6 @@ class LlamaServer:
self.process.wait(timeout=5) self.process.wait(timeout=5)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
self.process.kill() self.process.kill()
self.process.wait(timeout=5)
self.process = None self.process = None
@ -542,8 +474,6 @@ def filter_stderr_with_progress(process_stderr):
inline (overwriting the same line) until completion. inline (overwriting the same line) until completion.
""" """
progress_re = re.compile(r'slot update_slots: id.*progress = (\d+\.\d+)') progress_re = re.compile(r'slot update_slots: id.*progress = (\d+\.\d+)')
ansi_re = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]')
log_prefix_re = re.compile(r'^[IWED] ')
last_was_progress = False last_was_progress = False
try: try:
@ -562,7 +492,6 @@ def filter_stderr_with_progress(process_stderr):
line_bytes, buffer = buffer.split(b'\n', 1) line_bytes, buffer = buffer.split(b'\n', 1)
try: try:
line = line_bytes.decode('utf-8', errors='replace').strip('\r\n') line = line_bytes.decode('utf-8', errors='replace').strip('\r\n')
line = log_prefix_re.sub('', ansi_re.sub('', line))
if line: # Process non-empty lines if line: # Process non-empty lines
match = progress_re.search(line) match = progress_re.search(line)
@ -582,7 +511,7 @@ def filter_stderr_with_progress(process_stderr):
last_was_progress = (progress < 1.0) last_was_progress = (progress < 1.0)
# skip noise lines # skip noise lines
elif not (line.startswith(('srv ', 'slot ')) or 'log_server_r: request: GET /health' in line or 'No parser definition detected' in line): elif not (line.startswith(('srv ', 'slot ')) or 'log_server_r: request: GET /health' in line):
# if we were in progress, finish that line first # if we were in progress, finish that line first
if last_was_progress: if last_was_progress:
print(file=sys.stderr) print(file=sys.stderr)
@ -598,5 +527,5 @@ def filter_stderr_with_progress(process_stderr):
finally: finally:
try: try:
process_stderr.close() process_stderr.close()
except Exception: except:
pass pass

View file

@ -1,10 +1,11 @@
import functools import functools
from collections import OrderedDict from collections import OrderedDict
import gradio as gr
loaders_and_params = OrderedDict({ loaders_and_params = OrderedDict({
'llama.cpp': [ 'llama.cpp': [
'gpu_layers', 'gpu_layers',
'fit_target',
'cpu_moe', 'cpu_moe',
'threads', 'threads',
'threads_batch', 'threads_batch',
@ -15,22 +16,18 @@ loaders_and_params = OrderedDict({
'tensor_split', 'tensor_split',
'extra_flags', 'extra_flags',
'streaming_llm', 'streaming_llm',
'rope_freq_base',
'compress_pos_emb',
'row_split', 'row_split',
'no_kv_offload', 'no_kv_offload',
'no_mmap', 'no_mmap',
'mlock', 'mlock',
'numa', 'numa',
'parallel',
'model_draft', 'model_draft',
'draft_max', 'draft_max',
'gpu_layers_draft', 'gpu_layers_draft',
'device_draft', 'device_draft',
'ctx_size_draft', 'ctx_size_draft',
'ngram_header',
'spec_type',
'spec_ngram_size_n',
'spec_ngram_size_m',
'spec_ngram_min_hits',
'speculative_decoding_accordion', 'speculative_decoding_accordion',
'mmproj', 'mmproj',
'mmproj_accordion', 'mmproj_accordion',
@ -39,6 +36,8 @@ loaders_and_params = OrderedDict({
'Transformers': [ 'Transformers': [
'gpu_split', 'gpu_split',
'cpu_memory', 'cpu_memory',
'alpha_value',
'compress_pos_emb',
'compute_dtype', 'compute_dtype',
'quant_type', 'quant_type',
'load_in_8bit', 'load_in_8bit',
@ -65,12 +64,46 @@ loaders_and_params = OrderedDict({
'gpu_split', 'gpu_split',
'model_draft', 'model_draft',
'draft_max', 'draft_max',
'ctx_size_draft',
'speculative_decoding_accordion', 'speculative_decoding_accordion',
'enable_tp', 'enable_tp',
'tp_backend', 'tp_backend',
], ],
'ExLlamav2_HF': [
'ctx_size',
'cache_type',
'gpu_split',
'alpha_value',
'compress_pos_emb',
'num_experts_per_token',
'autosplit',
'enable_tp',
'no_flash_attn',
'no_xformers',
'no_sdpa',
'cfg_cache',
'no_use_fast',
],
'ExLlamav2': [
'ctx_size',
'cache_type',
'gpu_split',
'alpha_value',
'compress_pos_emb',
'num_experts_per_token',
'autosplit',
'enable_tp',
'no_flash_attn',
'no_xformers',
'no_sdpa',
'model_draft',
'draft_max',
'ctx_size_draft',
'speculative_decoding_accordion',
],
'TensorRT-LLM': [ 'TensorRT-LLM': [
'ctx_size', 'ctx_size',
'cpp_runner',
'tensorrt_llm_info', 'tensorrt_llm_info',
] ]
}) })
@ -95,8 +128,6 @@ def transformers_samplers():
'tfs', 'tfs',
'top_a', 'top_a',
'top_n_sigma', 'top_n_sigma',
'adaptive_target',
'adaptive_decay',
'dry_multiplier', 'dry_multiplier',
'dry_allowed_length', 'dry_allowed_length',
'dry_base', 'dry_base',
@ -152,8 +183,54 @@ loaders_samplers = {
'tfs', 'tfs',
'top_a', 'top_a',
'top_n_sigma', 'top_n_sigma',
'adaptive_target', 'dry_multiplier',
'adaptive_decay', 'dry_allowed_length',
'dry_base',
'repetition_penalty',
'frequency_penalty',
'presence_penalty',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'repetition_penalty_range',
'guidance_scale',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'do_sample',
'dynamic_temperature',
'temperature_last',
'auto_max_new_tokens',
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'reasoning_effort',
'skip_special_tokens',
'seed',
'sampler_priority',
'custom_token_bans',
'negative_prompt',
'dry_sequence_breakers',
'grammar_string',
'grammar_file_row',
},
'ExLlamav2_HF': {
'temperature',
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'smoothing_curve',
'min_p',
'top_p',
'top_k',
'typical_p',
'xtc_threshold',
'xtc_probability',
'epsilon_cutoff',
'eta_cutoff',
'tfs',
'top_a',
'top_n_sigma',
'dry_multiplier', 'dry_multiplier',
'dry_allowed_length', 'dry_allowed_length',
'dry_base', 'dry_base',
@ -189,8 +266,6 @@ loaders_samplers = {
'min_p', 'min_p',
'top_p', 'top_p',
'top_k', 'top_k',
'adaptive_target',
'adaptive_decay',
'repetition_penalty', 'repetition_penalty',
'frequency_penalty', 'frequency_penalty',
'presence_penalty', 'presence_penalty',
@ -201,10 +276,45 @@ loaders_samplers = {
'ban_eos_token', 'ban_eos_token',
'add_bos_token', 'add_bos_token',
'enable_thinking', 'enable_thinking',
'reasoning_effort',
'seed', 'seed',
'skip_special_tokens', 'skip_special_tokens',
}, },
'ExLlamav2': {
'temperature',
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'min_p',
'top_p',
'top_k',
'typical_p',
'xtc_threshold',
'xtc_probability',
'tfs',
'top_a',
'dry_multiplier',
'dry_allowed_length',
'dry_base',
'repetition_penalty',
'frequency_penalty',
'presence_penalty',
'repetition_penalty_range',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'dynamic_temperature',
'temperature_last',
'auto_max_new_tokens',
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'reasoning_effort',
'skip_special_tokens',
'seed',
'custom_token_bans',
'dry_sequence_breakers',
},
'llama.cpp': { 'llama.cpp': {
'temperature', 'temperature',
'dynatemp_low', 'dynatemp_low',
@ -217,8 +327,6 @@ loaders_samplers = {
'xtc_threshold', 'xtc_threshold',
'xtc_probability', 'xtc_probability',
'top_n_sigma', 'top_n_sigma',
'adaptive_target',
'adaptive_decay',
'dry_multiplier', 'dry_multiplier',
'dry_allowed_length', 'dry_allowed_length',
'dry_base', 'dry_base',
@ -238,7 +346,6 @@ loaders_samplers = {
'reasoning_effort', 'reasoning_effort',
'seed', 'seed',
'sampler_priority', 'sampler_priority',
'custom_token_bans',
'dry_sequence_breakers', 'dry_sequence_breakers',
'grammar_string', 'grammar_string',
'grammar_file_row', 'grammar_file_row',
@ -247,16 +354,11 @@ loaders_samplers = {
'temperature', 'temperature',
'top_p', 'top_p',
'top_k', 'top_k',
'min_p',
'repetition_penalty', 'repetition_penalty',
'frequency_penalty', 'frequency_penalty',
'presence_penalty', 'presence_penalty',
'no_repeat_ngram_size',
'auto_max_new_tokens', 'auto_max_new_tokens',
'ban_eos_token', 'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
'seed',
} }
} }
@ -272,7 +374,6 @@ def list_all_samplers():
def blacklist_samplers(loader, dynamic_temperature): def blacklist_samplers(loader, dynamic_temperature):
import gradio as gr
all_samplers = list_all_samplers() all_samplers = list_all_samplers()
output = [] output = []
@ -298,58 +399,7 @@ def get_all_params():
return sorted(all_params) return sorted(all_params)
def list_model_elements():
return [
'filter_by_loader',
'loader',
'cpu_memory',
'gpu_layers',
'fit_target',
'cpu_moe',
'threads',
'threads_batch',
'batch_size',
'ubatch_size',
'ctx_size',
'cache_type',
'tensor_split',
'extra_flags',
'streaming_llm',
'gpu_split',
'compute_dtype',
'quant_type',
'load_in_8bit',
'load_in_4bit',
'attn_implementation',
'cpu',
'disk',
'row_split',
'no_kv_offload',
'no_mmap',
'mlock',
'numa',
'parallel',
'use_double_quant',
'bf16',
'enable_tp',
'tp_backend',
'cfg_cache',
'no_use_fast',
'model_draft',
'draft_max',
'gpu_layers_draft',
'device_draft',
'ctx_size_draft',
'spec_type',
'spec_ngram_size_n',
'spec_ngram_size_m',
'spec_ngram_min_hits',
'mmproj',
]
def make_loader_params_visible(loader): def make_loader_params_visible(loader):
import gradio as gr
params = [] params = []
all_params = get_all_params() all_params = get_all_params()
if loader in loaders_and_params: if loader in loaders_and_params:

View file

@ -70,21 +70,26 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
from modules import sampler_hijack from modules import sampler_hijack
from modules.torch_utils import get_device from modules.torch_utils import get_device
is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model' is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model'
if not use_samplers: if not use_samplers:
state = {'stream': True} state = {'stream': True}
if use_samplers: if use_samplers:
if is_non_hf_exllamav2:
# sampling is all done in C++ for exllama, so it is really hard to hijack
logger.error("Sampler hijacking is not supported non-Huggingface loaders.")
return 'Error: Sampler hijacking is not supported non-Huggingface loaders. Please disable the "Use samplers" option.', previous
state['max_new_tokens'] = 1 state['max_new_tokens'] = 1
state['auto_max_new_tokens'] = False state['auto_max_new_tokens'] = False
state.setdefault('stream', True)
for _ in generate_reply(prompt, state): for _ in generate_reply(prompt, state):
pass pass
scores = sampler_hijack.global_scores[-1] scores = sampler_hijack.global_scores[-1]
else: else:
if is_non_hf_exllamav3: if is_non_hf_exllamav2 or is_non_hf_exllamav3:
device = get_device() device = get_device()
tokens = shared.tokenizer.encode(prompt) tokens = shared.tokenizer.encode(prompt)
if device: if device:
@ -100,7 +105,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
output = shared.model(input_ids=tokens) output = shared.model(input_ids=tokens)
scores = output['logits'][-1][-1] scores = output['logits'][-1][-1]
probs = torch.softmax(scores.detach(), dim=-1, dtype=torch.float) probs = torch.softmax(scores, dim=-1, dtype=torch.float)
topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True) topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
if hasattr(shared.tokenizer, 'convert_ids_to_tokens'): if hasattr(shared.tokenizer, 'convert_ids_to_tokens'):
tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices] tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices]
@ -115,7 +120,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
if isinstance(key, bytes): if isinstance(key, bytes):
try: try:
key = key.decode() key = key.decode()
except Exception: except:
key = key.decode('latin') key = key.decode('latin')
output[key] = row[0] output[key] = row[0]

View file

@ -53,7 +53,7 @@ def get_single(value_type, file):
value = file.read(value_length) value = file.read(value_length)
try: try:
value = value.decode('utf-8') value = value.decode('utf-8')
except Exception: except:
pass pass
else: else:
type_str = _simple_value_packing.get(value_type) type_str = _simple_value_packing.get(value_type)

View file

@ -20,6 +20,8 @@ def load_model(model_name, loader=None):
'Transformers': transformers_loader, 'Transformers': transformers_loader,
'ExLlamav3_HF': ExLlamav3_HF_loader, 'ExLlamav3_HF': ExLlamav3_HF_loader,
'ExLlamav3': ExLlamav3_loader, 'ExLlamav3': ExLlamav3_loader,
'ExLlamav2_HF': ExLlamav2_HF_loader,
'ExLlamav2': ExLlamav2_loader,
'TensorRT-LLM': TensorRT_LLM_loader, 'TensorRT-LLM': TensorRT_LLM_loader,
} }
@ -38,9 +40,6 @@ def load_model(model_name, loader=None):
sampler_hijack.hijack_samplers() sampler_hijack.hijack_samplers()
shared.args.loader = loader shared.args.loader = loader
if loader != 'llama.cpp' and shared.args.ctx_size == 0:
shared.args.ctx_size = 8192
output = load_func_map[loader](model_name) output = load_func_map[loader](model_name)
if type(output) is tuple: if type(output) is tuple:
model, tokenizer = output model, tokenizer = output
@ -55,10 +54,7 @@ def load_model(model_name, loader=None):
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings}) shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp': if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
if shared.args.ctx_size > 0:
shared.settings['truncation_length'] = shared.args.ctx_size shared.settings['truncation_length'] = shared.args.ctx_size
elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx:
shared.settings['truncation_length'] = model.n_ctx
shared.is_multimodal = False shared.is_multimodal = False
if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'): if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):
@ -112,6 +108,19 @@ def ExLlamav3_loader(model_name):
return model, tokenizer return model, tokenizer
def ExLlamav2_HF_loader(model_name):
from modules.exllamav2_hf import Exllamav2HF
return Exllamav2HF.from_pretrained(model_name)
def ExLlamav2_loader(model_name):
from modules.exllamav2 import Exllamav2Model
model, tokenizer = Exllamav2Model.from_pretrained(model_name)
return model, tokenizer
def TensorRT_LLM_loader(model_name): def TensorRT_LLM_loader(model_name):
try: try:
from modules.tensorrt_llm import TensorRTLLMModel from modules.tensorrt_llm import TensorRTLLMModel
@ -119,7 +128,7 @@ def TensorRT_LLM_loader(model_name):
raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.") raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.")
model = TensorRTLLMModel.from_pretrained(model_name) model = TensorRTLLMModel.from_pretrained(model_name)
return model, model.tokenizer return model
def unload_model(keep_model_name=False): def unload_model(keep_model_name=False):
@ -129,10 +138,10 @@ def unload_model(keep_model_name=False):
model_class_name = shared.model.__class__.__name__ model_class_name = shared.model.__class__.__name__
is_llamacpp = (model_class_name == 'LlamaServer') is_llamacpp = (model_class_name == 'LlamaServer')
if model_class_name in ['Exllamav3Model', 'Exllamav3HF', 'TensorRTLLMModel']: if model_class_name in ['Exllamav3Model', 'Exllamav3HF']:
shared.model.unload()
elif model_class_name in ['Exllamav2Model', 'Exllamav2HF'] and hasattr(shared.model, 'unload'):
shared.model.unload() shared.model.unload()
elif model_class_name == 'LlamaServer':
shared.model.stop()
shared.model = shared.tokenizer = None shared.model = shared.tokenizer = None
shared.lora_names = [] shared.lora_names = []

View file

@ -1,12 +1,14 @@
import functools import functools
import json import json
import re import re
import subprocess
from math import floor from math import floor
from pathlib import Path from pathlib import Path
import gradio as gr
import yaml import yaml
from modules import loaders, metadata_gguf, shared from modules import chat, loaders, metadata_gguf, shared, ui
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.utils import resolve_model_path from modules.utils import resolve_model_path
@ -15,6 +17,9 @@ def get_fallback_settings():
return { return {
'bf16': False, 'bf16': False,
'ctx_size': 8192, 'ctx_size': 8192,
'rope_freq_base': 0,
'compress_pos_emb': 1,
'alpha_value': 1,
'truncation_length': shared.settings['truncation_length'], 'truncation_length': shared.settings['truncation_length'],
'truncation_length_info': shared.settings['truncation_length'], 'truncation_length_info': shared.settings['truncation_length'],
'skip_special_tokens': shared.settings['skip_special_tokens'], 'skip_special_tokens': shared.settings['skip_special_tokens'],
@ -64,19 +69,21 @@ def get_model_metadata(model):
for k in metadata: for k in metadata:
if k.endswith('.context_length'): if k.endswith('.context_length'):
model_settings['ctx_size'] = 0 model_settings['ctx_size'] = min(metadata[k], 8192)
model_settings['truncation_length_info'] = metadata[k] model_settings['truncation_length_info'] = metadata[k]
elif k.endswith('rope.freq_base'):
model_settings['rope_freq_base'] = metadata[k]
elif k.endswith('rope.scale_linear'):
model_settings['compress_pos_emb'] = metadata[k]
elif k.endswith('rope.scaling.factor'):
model_settings['compress_pos_emb'] = metadata[k]
elif k.endswith('.block_count'): elif k.endswith('.block_count'):
model_settings['gpu_layers'] = -1 model_settings['gpu_layers'] = metadata[k] + 1
model_settings['max_gpu_layers'] = metadata[k] + 1 model_settings['max_gpu_layers'] = metadata[k] + 1
if 'tokenizer.chat_template' in metadata: if 'tokenizer.chat_template' in metadata:
template = metadata['tokenizer.chat_template'] template = metadata['tokenizer.chat_template']
if 'tokenizer.ggml.eos_token_id' in metadata:
eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']] eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']]
else:
eos_token = ""
if 'tokenizer.ggml.bos_token_id' in metadata: if 'tokenizer.ggml.bos_token_id' in metadata:
bos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.bos_token_id']] bos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.bos_token_id']]
else: else:
@ -110,6 +117,15 @@ def get_model_metadata(model):
model_settings['ctx_size'] = min(value, 8192) model_settings['ctx_size'] = min(value, 8192)
break break
if 'rope_theta' in metadata:
model_settings['rope_freq_base'] = metadata['rope_theta']
elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']:
model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta']
if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
if metadata['rope_scaling']['type'] == 'linear':
model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16': if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
model_settings['bf16'] = True model_settings['bf16'] = True
@ -163,6 +179,10 @@ def get_model_metadata(model):
if 'instruction_template' not in model_settings: if 'instruction_template' not in model_settings:
model_settings['instruction_template'] = 'Alpaca' model_settings['instruction_template'] = 'Alpaca'
# Ignore rope_freq_base if set to the default value
if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000:
model_settings.pop('rope_freq_base')
# Apply user settings from user_data/models/config-user.yaml # Apply user settings from user_data/models/config-user.yaml
settings = shared.user_config settings = shared.user_config
for pat in settings: for pat in settings:
@ -176,7 +196,7 @@ def get_model_metadata(model):
# Load instruction template if defined by name rather than by value # Load instruction template if defined by name rather than by value
if model_settings['instruction_template'] != 'Custom (obtained from model metadata)': if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
model_settings['instruction_template_str'] = load_instruction_template(model_settings['instruction_template']) model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
return model_settings return model_settings
@ -193,8 +213,12 @@ def infer_loader(model_name, model_settings, hf_quant_method=None):
loader = 'llama.cpp' loader = 'llama.cpp'
elif hf_quant_method == 'exl3': elif hf_quant_method == 'exl3':
loader = 'ExLlamav3' loader = 'ExLlamav3'
elif hf_quant_method in ['exl2', 'gptq']:
loader = 'ExLlamav2_HF'
elif re.match(r'.*exl3', model_name.lower()): elif re.match(r'.*exl3', model_name.lower()):
loader = 'ExLlamav3' loader = 'ExLlamav3'
elif re.match(r'.*exl2', model_name.lower()):
loader = 'ExLlamav2_HF'
else: else:
loader = 'Transformers' loader = 'Transformers'
@ -205,7 +229,7 @@ def update_model_parameters(state, initial=False):
''' '''
UI: update the command-line arguments based on the interface values UI: update the command-line arguments based on the interface values
''' '''
elements = loaders.list_model_elements() # the names of the parameters elements = ui.list_model_elements() # the names of the parameters
for i, element in enumerate(elements): for i, element in enumerate(elements):
if element not in state: if element not in state:
@ -225,11 +249,10 @@ def apply_model_settings_to_state(model, state):
''' '''
UI: update the state variable with the model settings UI: update the state variable with the model settings
''' '''
import gradio as gr
model_settings = get_model_metadata(model) model_settings = get_model_metadata(model)
if 'loader' in model_settings: if 'loader' in model_settings:
loader = model_settings.pop('loader') loader = model_settings.pop('loader')
if not (loader == 'ExLlamav3_HF' and state['loader'] == 'ExLlamav3'): if not ((loader == 'ExLlamav2_HF' and state['loader'] == 'ExLlamav2') or (loader == 'ExLlamav3_HF' and state['loader'] == 'ExLlamav3')):
state['loader'] = loader state['loader'] = loader
for k in model_settings: for k in model_settings:
@ -238,18 +261,16 @@ def apply_model_settings_to_state(model, state):
# Handle GPU layers and VRAM update for llama.cpp # Handle GPU layers and VRAM update for llama.cpp
if state['loader'] == 'llama.cpp' and 'gpu_layers' in model_settings: if state['loader'] == 'llama.cpp' and 'gpu_layers' in model_settings:
gpu_layers = model_settings['gpu_layers'] # -1 (auto) by default, or user-saved value vram_info, gpu_layers_update = update_gpu_layers_and_vram(
max_layers = model_settings.get('max_gpu_layers', 256)
state['gpu_layers'] = gr.update(value=gpu_layers, maximum=max_layers)
vram_info = update_gpu_layers_and_vram(
state['loader'], state['loader'],
model, model,
gpu_layers, model_settings['gpu_layers'],
state['ctx_size'], state['ctx_size'],
state['cache_type'], state['cache_type'],
auto_adjust=True
) )
state['gpu_layers'] = gpu_layers_update
state['vram_info'] = vram_info state['vram_info'] = vram_info
return state return state
@ -268,7 +289,7 @@ def save_model_settings(model, state):
if model_regex not in user_config: if model_regex not in user_config:
user_config[model_regex] = {} user_config[model_regex] = {}
for k in loaders.list_model_elements(): for k in ui.list_model_elements():
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]: if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
user_config[model_regex][k] = state[k] user_config[model_regex][k] = state[k]
@ -387,113 +408,120 @@ def estimate_vram(gguf_file, gpu_layers, ctx_size, cache_type):
return vram return vram
def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type): def get_nvidia_vram(return_free=True):
""" """
Compute the estimated VRAM usage for the given GPU layers and return Calculates VRAM statistics across all NVIDIA GPUs by parsing nvidia-smi output.
an HTML string for the UI display.
Args:
return_free (bool): If True, returns free VRAM. If False, returns total VRAM.
Returns:
int: Either the total free VRAM or total VRAM in MiB summed across all detected NVIDIA GPUs.
Returns -1 if nvidia-smi command fails (not found, error, etc.).
Returns 0 if nvidia-smi succeeds but no GPU memory info found.
""" """
if loader != 'llama.cpp' or model in ["None", None] or not model.endswith(".gguf") or gpu_layers < 0 or ctx_size == 0: try:
return f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">auto</span></div>" # Execute nvidia-smi command
result = subprocess.run(
vram_usage = estimate_vram(model, gpu_layers, ctx_size, cache_type) ['nvidia-smi'],
return f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">{vram_usage:.0f} MiB</span></div>" capture_output=True,
text=True,
check=False
def load_instruction_template(template):
if template == 'None':
return ''
for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']:
if filepath.exists():
break
else:
return ''
with open(filepath, 'r', encoding='utf-8') as f:
file_contents = f.read()
data = yaml.safe_load(file_contents)
if 'instruction_template' in data:
return data['instruction_template']
else:
return _jinja_template_from_old_format(data)
def _jinja_template_from_old_format(params, verbose=False):
MASTER_TEMPLATE = """
{%- set ns = namespace(found=false) -%}
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
{%- set ns.found = true -%}
{%- endif -%}
{%- endfor -%}
{%- if not ns.found -%}
{{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
{%- endif %}
{%- for message in messages %}
{%- if message['role'] == 'system' -%}
{{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
{%- else -%}
{%- if message['role'] == 'user' -%}
{{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
{%- else -%}
{{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{-'<|PRE-ASSISTANT-GENERATE|>'-}}
{%- endif -%}
"""
if 'context' in params and '<|system-message|>' in params['context']:
pre_system = params['context'].split('<|system-message|>')[0]
post_system = params['context'].split('<|system-message|>')[1]
else:
pre_system = ''
post_system = ''
pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]
pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
post_assistant = params['turn_template'].split('<|bot-message|>')[1]
def preprocess(string):
return string.replace('\n', '\\n').replace('\'', '\\\'')
pre_system = preprocess(pre_system)
post_system = preprocess(post_system)
pre_user = preprocess(pre_user)
post_user = preprocess(post_user)
pre_assistant = preprocess(pre_assistant)
post_assistant = preprocess(post_assistant)
if verbose:
print(
'\n',
repr(pre_system) + '\n',
repr(post_system) + '\n',
repr(pre_user) + '\n',
repr(post_user) + '\n',
repr(pre_assistant) + '\n',
repr(post_assistant) + '\n',
) )
result = MASTER_TEMPLATE # Check if nvidia-smi returned an error
if 'system_message' in params: if result.returncode != 0:
result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message'])) return -1
# Parse the output for memory usage patterns
output = result.stdout
# Find memory usage like "XXXXMiB / YYYYMiB"
# Captures used and total memory for each GPU
matches = re.findall(r"(\d+)\s*MiB\s*/\s*(\d+)\s*MiB", output)
if not matches:
# No GPUs found in expected format
return 0
total_vram_mib = 0
total_free_vram_mib = 0
for used_mem_str, total_mem_str in matches:
try:
used_mib = int(used_mem_str)
total_mib = int(total_mem_str)
total_vram_mib += total_mib
total_free_vram_mib += (total_mib - used_mib)
except ValueError:
# Skip malformed entries
pass
# Return either free or total VRAM based on the flag
return total_free_vram_mib if return_free else total_vram_mib
except FileNotFoundError:
# nvidia-smi not found (likely no NVIDIA drivers installed)
return -1
except Exception:
# Handle any other unexpected exceptions
return -1
def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type, auto_adjust=False, for_ui=True):
"""
Unified function to handle GPU layers and VRAM updates.
Args:
for_ui: If True, returns Gradio updates. If False, returns raw values.
Returns:
- If for_ui=True: (vram_info_update, gpu_layers_update) or just vram_info_update
- If for_ui=False: (vram_usage, adjusted_layers) or just vram_usage
"""
if loader != 'llama.cpp' or model in ["None", None] or not model.endswith(".gguf"):
vram_info = "<div id=\"vram-info\"'>Estimated VRAM to load the model:</div>"
if for_ui:
return (vram_info, gr.update()) if auto_adjust else vram_info
else: else:
result = result.replace('<|SYSTEM-MESSAGE|>', '') return (0, gpu_layers) if auto_adjust else 0
result = result.replace('<|PRE-SYSTEM|>', pre_system) # Get model settings including user preferences
result = result.replace('<|POST-SYSTEM|>', post_system) model_settings = get_model_metadata(model)
result = result.replace('<|PRE-USER|>', pre_user)
result = result.replace('<|POST-USER|>', post_user)
result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
result = result.replace('<|POST-ASSISTANT|>', post_assistant)
result = result.strip() current_layers = gpu_layers
max_layers = model_settings.get('max_gpu_layers', 256)
return result if auto_adjust:
# Check if this is a user-saved setting
user_config = shared.user_config
model_regex = Path(model).name + '$'
has_user_setting = model_regex in user_config and 'gpu_layers' in user_config[model_regex]
if not has_user_setting:
# No user setting, auto-adjust from the maximum
current_layers = max_layers # Start from max
# Auto-adjust based on available/total VRAM
# If a model is loaded and it's for the UI, use the total VRAM to avoid confusion
return_free = False if (for_ui and shared.model_name not in [None, 'None']) else True
available_vram = get_nvidia_vram(return_free=return_free)
if available_vram > 0:
tolerance = 577
while current_layers > 0 and estimate_vram(model, current_layers, ctx_size, cache_type) > available_vram - tolerance:
current_layers -= 1
# Calculate VRAM with current layers
vram_usage = estimate_vram(model, current_layers, ctx_size, cache_type)
if for_ui:
vram_info = f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">{vram_usage:.0f} MiB</span></div>"
if auto_adjust:
return vram_info, gr.update(value=current_layers, maximum=max_layers)
else:
return vram_info
else:
if auto_adjust:
return vram_usage, current_layers
else:
return vram_usage

View file

@ -1,28 +0,0 @@
import sys
from pathlib import Path
def resolve_user_data_dir():
    """
    Work out which directory holds the user data.

    Candidates are tried from highest to lowest priority:
    1. An explicit --user-data-dir value, read directly from sys.argv because
       this function runs before argparse has parsed anything. Both the
       "--user-data-dir PATH" and "--user-data-dir=PATH" spellings are accepted.
    2. When --portable is present, a 'user_data' folder located one level above
       the directory two levels up from this file, provided that folder exists.
    3. The relative path 'user_data'.

    Returns:
        pathlib.Path: the selected directory (not created or validated here).
    """
    flag = '--user-data-dir'
    install_root = Path(__file__).resolve().parent.parent

    argv = sys.argv
    for position, token in enumerate(argv):
        if token == flag:
            # A trailing flag with no value is ignored and scanning continues.
            if position + 1 < len(argv):
                return Path(argv[position + 1])
        elif token.startswith(flag + '='):
            return Path(token.split('=', 1)[1])

    # Portable builds may keep user_data next to the install folder.
    if '--portable' in argv:
        sibling = install_root.parent / 'user_data'
        if sibling.exists():
            return sibling

    return Path('user_data')

View file

@ -9,17 +9,17 @@ from modules.loaders import loaders_samplers
from modules.logging_colors import logger from modules.logging_colors import logger
default_preset_values = { def default_preset():
result = {
'temperature': 1, 'temperature': 1,
'dynatemp_low': 1, 'dynatemp_low': 1,
'dynatemp_high': 1, 'dynatemp_high': 1,
'dynatemp_exponent': 1, 'dynatemp_exponent': 1,
'smoothing_factor': 0, 'smoothing_factor': 0,
'smoothing_curve': 1, 'smoothing_curve': 1,
'min_p': 0,
'top_p': 1, 'top_p': 1,
'top_k': 0, 'top_k': 0,
'min_p': 0,
'top_n_sigma': 0,
'typical_p': 1, 'typical_p': 1,
'xtc_threshold': 0.1, 'xtc_threshold': 0.1,
'xtc_probability': 0, 'xtc_probability': 0,
@ -27,8 +27,7 @@ default_preset_values = {
'eta_cutoff': 0, 'eta_cutoff': 0,
'tfs': 1, 'tfs': 1,
'top_a': 0, 'top_a': 0,
'adaptive_target': 0, 'top_n_sigma': 0,
'adaptive_decay': 0.9,
'dry_multiplier': 0, 'dry_multiplier': 0,
'dry_allowed_length': 2, 'dry_allowed_length': 2,
'dry_base': 1.75, 'dry_base': 1.75,
@ -46,13 +45,9 @@ default_preset_values = {
'do_sample': True, 'do_sample': True,
'dynamic_temperature': False, 'dynamic_temperature': False,
'temperature_last': False, 'temperature_last': False,
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nadaptive_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram', 'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"', 'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
} }
def default_preset():
result = dict(default_preset_values)
if shared.args.portable: if shared.args.portable:
samplers = result['sampler_priority'].split('\n') samplers = result['sampler_priority'].split('\n')
@ -69,7 +64,7 @@ def presets_params():
def load_preset(name, verbose=False): def load_preset(name, verbose=False):
generate_params = default_preset() generate_params = default_preset()
if name not in ['None', None, '']: if name not in ['None', None, '']:
path = shared.user_data_dir / 'presets' / f'{name}.yaml' path = Path(f'user_data/presets/{name}.yaml')
if path.exists(): if path.exists():
with open(path, 'r') as infile: with open(path, 'r') as infile:
preset = yaml.safe_load(infile) preset = yaml.safe_load(infile)

View file

@ -8,7 +8,7 @@ def load_prompt(fname):
if not fname: if not fname:
# Create new file # Create new file
new_name = utils.current_time() new_name = utils.current_time()
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.parent.mkdir(parents=True, exist_ok=True)
initial_content = "In this story," initial_content = "In this story,"
prompt_path.write_text(initial_content, encoding='utf-8') prompt_path.write_text(initial_content, encoding='utf-8')
@ -18,7 +18,7 @@ def load_prompt(fname):
return initial_content return initial_content
file_path = shared.user_data_dir / 'logs' / 'notebook' / f'{fname}.txt' file_path = Path(f'user_data/logs/notebook/{fname}.txt')
if file_path.exists(): if file_path.exists():
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, 'r', encoding='utf-8') as f:
text = f.read() text = f.read()
@ -33,5 +33,5 @@ def count_tokens(text):
try: try:
tokens = get_encoded_length(text) tokens = get_encoded_length(text)
return str(tokens) return str(tokens)
except Exception: except:
return '0' return '0'

View file

@ -1,94 +0,0 @@
import html as html_module
# Known thinking-block layouts, as (start_tag, end_tag, content_start_tag) triples.
# A start_tag of None means the reasoning runs from the very beginning of the
# text; such end-only layouts must stay last so tagged layouts are tried first.
THINKING_FORMATS = [
    ('<think>', '</think>', None),
    ('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'),
    ('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'),
    ('<seed:think>', '</seed:think>', None),
    ('<|think|>', '<|end|>', '<|content|>'),  # Solar Open
    # ('Thinking Process:', '</think>', None),  # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming
    (None, '</think>', None),  # End-only variant (e.g., Qwen3-next)
]


def extract_reasoning(text, html_escaped=False):
    """Split a model reply into its reasoning block and its final answer.

    Each layout in THINKING_FORMATS is tried in order and the first one that
    matches decides the split.

    Args:
        text: The reply to inspect.
        html_escaped: When True, every tag is HTML-escaped before searching,
            for use on UI strings that have already been escaped.

    Returns:
        A (reasoning_content, final_content) tuple. reasoning_content is None
        when no thinking block is found. ('', '') is returned when the whole
        text is only a prefix of a start tag (a tag still arriving during
        streaming), so the partial tag is not shown.
    """
    if not text:
        return None, text

    def _tag(raw):
        # Tags must use the same escaping as the text being searched.
        return html_module.escape(raw) if html_escaped else raw

    text_len = len(text)

    for start_tag, end_tag, content_tag in THINKING_FORMATS:
        end_esc = _tag(end_tag)
        content_esc = _tag(content_tag) if content_tag else None

        if start_tag is not None:
            # Tagged layout: the start tag is mandatory.
            start_esc = _tag(start_tag)
            start_pos = text.find(start_esc)
            if start_pos == -1:
                # While streaming, the start tag may have arrived only in part.
                # If everything received so far is a prefix of it, emit nothing
                # so the fragment does not leak into the output.
                stripped = text.strip()
                if stripped and start_esc.startswith(stripped):
                    return '', ''

                continue

            thought_start = start_pos + len(start_esc)
            end_pos = text.find(end_esc, thought_start)
        else:
            # End-only layout: the end tag is mandatory and the reasoning
            # starts at the first character.
            end_pos = text.find(end_esc)
            if end_pos == -1:
                continue

            thought_start = 0

        if end_pos == -1:
            # No end tag yet: everything is reasoning unless the content tag
            # is already present and can mark the boundary instead.
            thought_end = content_start = text_len
            if content_esc:
                content_pos = text.find(content_esc, thought_start)
                if content_pos != -1:
                    thought_end = content_pos
                    content_start = content_pos + len(content_esc)
        elif content_esc:
            thought_end = end_pos
            content_pos = text.find(content_esc, end_pos)
            # If the content tag has not arrived yet (partial stream), hide
            # whatever sits between the end tag and it rather than leaking
            # intermediate tags as content.
            content_start = text_len if content_pos == -1 else content_pos + len(content_esc)
        else:
            thought_end = end_pos
            content_start = end_pos + len(end_esc)

        return text[thought_start:thought_end], text[content_start:]

    # Standalone GPT-OSS final-channel marker with no analysis/commentary block
    # before it (the model skipped thinking entirely).
    for marker in ('<|start|>assistant<|channel|>final<|message|>', '<|channel|>final<|message|>'):
        before, found, after = text.partition(_tag(marker))
        if found:
            return (before.strip() or None), after

    return None, text

View file

@ -235,73 +235,6 @@ class TopNSigmaLogitsWarper(LogitsProcessor):
return scores return scores
class AdaptivePLogitsWarper(LogitsProcessor):
    '''
    Adaptive-p sampling: a stateful sampler that prefers tokens whose
    probability lies near a target value, steering that target over time with
    an EMA-based control loop.

    Matches the llama.cpp implementation from PR #17927.

    The token is sampled inside this processor; every other entry of the
    returned scores is set to filter_value.
    '''

    DISTRIBUTION_WIDTH = 0.3
    PEAK_LOGIT_VALUE = 5.0
    SHARPNESS = 10.0
    INV_WIDTH = 1.0 / DISTRIBUTION_WIDTH

    def __init__(self, adaptive_target, adaptive_decay, filter_value=-float("Inf"), min_tokens_to_keep=1):
        self.target = adaptive_target
        self.decay = min(adaptive_decay, 0.99)  # decay is capped at 0.99
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep  # stored; not read by __call__

        # Start the EMA at equilibrium, as though the target had already been
        # achieved on every previous step.
        if self.decay < 1.0:
            self.weighted_sum = self.target / (1.0 - self.decay)
            self.total_weight = 1.0 / (1.0 - self.decay)
        else:
            self.weighted_sum = 0.0
            self.total_weight = 0.0

    def __call__(self, input_ids, scores):
        # Only the first row of scores is sampled from.
        row = scores[0]

        # Probabilities of the untransformed distribution.
        base_probs = torch.softmax(row, dim=-1)

        # Proportional control on the EMA: aim past the target by however far
        # the running average has drifted from it, clamped to [0, 1].
        ema_avg = self.weighted_sum / self.total_weight if self.total_weight > 0 else self.target
        adapted_target = max(0.0, min(1.0, 2.0 * self.target - ema_avg))

        # Distance-based transform: quadratic close to the adapted target for
        # fine differentiation, easing into linear decay in the tails so those
        # tokens are properly suppressed after softmax.
        distance = torch.abs((base_probs - adapted_target) * self.INV_WIDTH)
        shaped = self.PEAK_LOGIT_VALUE - self.SHARPNESS * distance * distance / (1.0 + distance)

        # Tokens already masked by earlier samplers (non-finite logits) keep
        # their original value.
        shaped = torch.where(torch.isfinite(row), shaped, row)

        # Draw one token from the transformed distribution.
        selected = torch.multinomial(torch.softmax(shaped, dim=-1), num_samples=1, replacement=True)
        chosen = selected[0]

        # Feed the chosen token's original probability back into the EMA.
        picked_prob = base_probs[chosen].item()
        self.weighted_sum = picked_prob + self.decay * self.weighted_sum
        self.total_weight = 1.0 + self.decay * self.total_weight

        # Filter out everything except the chosen token.
        drop = torch.ones_like(row, dtype=torch.bool)
        drop[chosen] = False
        return scores.masked_fill(drop.unsqueeze(0), self.filter_value)
# Exclude Top Choices (XTC) # Exclude Top Choices (XTC)
class XTCLogitsWarper(LogitsProcessor): class XTCLogitsWarper(LogitsProcessor):
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")): def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
@ -642,15 +575,6 @@ def get_logits_processor_patch(self, **kwargs):
) )
) )
if generation_config.adaptive_target is not None and generation_config.adaptive_target > 0.0:
warpers_to_add.append(
AdaptivePLogitsWarper(
adaptive_target=generation_config.adaptive_target,
adaptive_decay=generation_config.adaptive_decay,
min_tokens_to_keep=min_tokens_to_keep
)
)
if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0: if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
warpers_to_add.append( warpers_to_add.append(
XTCLogitsWarper( XTCLogitsWarper(
@ -716,7 +640,6 @@ def get_logits_processor_patch(self, **kwargs):
'TemperatureLogitsWarperCustom': 'temperature', 'TemperatureLogitsWarperCustom': 'temperature',
'TopALogitsWarper': 'top_a', 'TopALogitsWarper': 'top_a',
'TopNSigmaLogitsWarper': 'top_n_sigma', 'TopNSigmaLogitsWarper': 'top_n_sigma',
'AdaptivePLogitsWarper': 'adaptive_p',
'TopKLogitsWarper': 'top_k', 'TopKLogitsWarper': 'top_k',
'TopPLogitsWarper': 'top_p', 'TopPLogitsWarper': 'top_p',
'TypicalLogitsWarper': 'typical_p', 'TypicalLogitsWarper': 'typical_p',
@ -765,8 +688,6 @@ def generation_config_init_patch(self, **kwargs):
self.tfs = kwargs.pop("tfs", 1.0) self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0) self.top_a = kwargs.pop("top_a", 0.0)
self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0) self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0)
self.adaptive_target = kwargs.pop("adaptive_target", 0.0)
self.adaptive_decay = kwargs.pop("adaptive_decay", 0.9)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0) self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1) self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
self.mirostat_tau = kwargs.pop("mirostat_tau", 5) self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
@ -780,7 +701,7 @@ def generation_config_init_patch(self, **kwargs):
self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1) self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
self.xtc_probability = kwargs.pop("xtc_probability", 0) self.xtc_probability = kwargs.pop("xtc_probability", 0)
self.temperature_last = kwargs.pop("temperature_last", False) self.temperature_last = kwargs.pop("temperature_last", False)
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'adaptive_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram']) self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
def hijack_samplers(): def hijack_samplers():

View file

@ -58,8 +58,9 @@ class SaneListIndentProcessor(ListIndentProcessor):
def test(self, parent: etree.Element, block: str) -> bool: def test(self, parent: etree.Element, block: str) -> bool:
return block.startswith(' ' * MIN_NESTED_LIST_INDENT) and \ return block.startswith(' ' * MIN_NESTED_LIST_INDENT) and \
not self.parser.state.isstate('detabbed') and \ not self.parser.state.isstate('detabbed') and \
(parent.tag in self.ITEM_TYPES or (len(parent) and parent[-1] is not None and (parent[-1].tag in (parent.tag in self.ITEM_TYPES or
self.LIST_TYPES))) (len(parent) and parent[-1] is not None and
(parent[-1].tag in self.LIST_TYPES)))
def get_level(self, parent: etree.Element, block: str) -> tuple[int, etree.Element]: def get_level(self, parent: etree.Element, block: str) -> tuple[int, etree.Element]:
""" Get level of indentation based on list level. """ """ Get level of indentation based on list level. """
@ -78,7 +79,8 @@ class SaneListIndentProcessor(ListIndentProcessor):
# Step through children of tree to find matching indent level. # Step through children of tree to find matching indent level.
while indent_level > level: while indent_level > level:
child = self.lastChild(parent) child = self.lastChild(parent)
if child is not None and (child.tag in self.LIST_TYPES or child.tag in self.ITEM_TYPES): if (child is not None and
(child.tag in self.LIST_TYPES or child.tag in self.ITEM_TYPES)):
if child.tag in self.LIST_TYPES: if child.tag in self.LIST_TYPES:
level += 1 level += 1
parent = child parent = child
@ -122,14 +124,16 @@ class SaneOListProcessor(OListProcessor):
def __init__(self, parser: blockparser.BlockParser): def __init__(self, parser: blockparser.BlockParser):
super().__init__(parser) super().__init__(parser)
max_list_start_indent = self.tab_length # This restriction stems from the 'CodeBlockProcessor' class,
# which automatically matches blocks with an indent = self.tab_length
max_list_start_indent = self.tab_length - 1
# Detect an item (e.g., `1. item`) # Detect an item (e.g., `1. item`)
self.RE = re.compile(r'^[ ]{0,%d}[\*_]{0,2}\d+\.[ ]+(.*)' % max_list_start_indent) self.RE = re.compile(r'^[ ]{0,%d}[\*_]{0,2}\d+\.[ ]+(.*)' % max_list_start_indent)
# Detect items on secondary lines. they can be of either list type. # Detect items on secondary lines. they can be of either list type.
self.CHILD_RE = re.compile(r'^[ ]{0,%d}([\*_]{0,2})((\d+\.))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1)) self.CHILD_RE = re.compile(r'^[ ]{0,%d}([\*_]{0,2})((\d+\.))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1))
# Detect indented (nested) items of either type # Detect indented (nested) items of either type
self.INDENT_RE = re.compile(r'^[ ]{%d,%d}[\*_]{0,2}((\d+\.)|[*+-])[ ]+.*' % self.INDENT_RE = re.compile(r'^[ ]{%d,%d}[\*_]{0,2}((\d+\.)|[*+-])[ ]+.*' %
(MIN_NESTED_LIST_INDENT, self.tab_length * 2)) (MIN_NESTED_LIST_INDENT, self.tab_length * 2 - 1))
def run(self, parent: etree.Element, blocks: list[str]) -> None: def run(self, parent: etree.Element, blocks: list[str]) -> None:
# Check for multiple items in one block. # Check for multiple items in one block.
@ -238,7 +242,7 @@ class SaneUListProcessor(SaneOListProcessor):
def __init__(self, parser: blockparser.BlockParser): def __init__(self, parser: blockparser.BlockParser):
super().__init__(parser) super().__init__(parser)
# Detect an item (e.g., `- item` or `+ item` or `* item`). # Detect an item (e.g., `- item` or `+ item` or `* item`).
max_list_start_indent = self.tab_length max_list_start_indent = self.tab_length - 1
self.RE = re.compile(r'^[ ]{0,%d}[*+-][ ]+(.*)' % max_list_start_indent) self.RE = re.compile(r'^[ ]{0,%d}[*+-][ ]+(.*)' % max_list_start_indent)
self.CHILD_RE = re.compile(r'^[ ]{0,%d}(([*+-]))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1)) self.CHILD_RE = re.compile(r'^[ ]{0,%d}(([*+-]))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1))
@ -271,7 +275,7 @@ class SaneParagraphProcessor(ParagraphProcessor):
def __init__(self, parser: BlockParser): def __init__(self, parser: BlockParser):
super().__init__(parser) super().__init__(parser)
max_list_start_indent = self.tab_length max_list_start_indent = self.tab_length - 1
self.LIST_RE = re.compile(r"\s{2}\n(\s{0,%d}[\d+*-])" % max_list_start_indent) self.LIST_RE = re.compile(r"\s{2}\n(\s{0,%d}[\d+*-])" % max_list_start_indent)
def run(self, parent: etree.Element, blocks: list[str]) -> None: def run(self, parent: etree.Element, blocks: list[str]) -> None:
@ -327,9 +331,6 @@ class SaneListExtension(Extension):
md.parser.blockprocessors.register(SaneUListProcessor(md.parser), 'ulist', 30) md.parser.blockprocessors.register(SaneUListProcessor(md.parser), 'ulist', 30)
md.parser.blockprocessors.register(SaneParagraphProcessor(md.parser), 'paragraph', 10) md.parser.blockprocessors.register(SaneParagraphProcessor(md.parser), 'paragraph', 10)
# Disable uncommon indented codeblocks (as opposed to fenced codeblocks delimited by "```")
md.parser.blockprocessors.deregister('code')
def makeExtension(**kwargs): # pragma: no cover def makeExtension(**kwargs): # pragma: no cover
return SaneListExtension(**kwargs) return SaneListExtension(**kwargs)

View file

@ -9,11 +9,7 @@ from pathlib import Path
import yaml import yaml
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.paths import resolve_user_data_dir from modules.presets import default_preset
from modules.presets import default_preset, default_preset_values
# Resolve user_data directory early (before argparse defaults are set)
user_data_dir = resolve_user_data_dir()
# Text model variables # Text model variables
model = None model = None
@ -46,12 +42,11 @@ parser = argparse.ArgumentParser(description="Text Generation Web UI", conflict_
# Basic settings # Basic settings
group = parser.add_argument_group('Basic settings') group = parser.add_argument_group('Basic settings')
group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user data directory. Default: auto-detected.') group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.')
group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.')
group.add_argument('--model', type=str, help='Name of the model to load by default.') group.add_argument('--model', type=str, help='Name of the model to load by default.')
group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.') group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.') group.add_argument('--model-dir', type=str, default='user_data/models', help='Path to directory with all the models.')
group.add_argument('--lora-dir', type=str, default=str(user_data_dir / 'loras'), help='Path to directory with all the loras.') group.add_argument('--lora-dir', type=str, default='user_data/loras', help='Path to directory with all the loras.')
group.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.') group.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.')
group.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See user_data/settings-template.yaml for an example. If you create a file called user_data/settings.yaml, this file will be loaded by default without the need to use the --settings flag.') group.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See user_data/settings-template.yaml for an example. If you create a file called user_data/settings.yaml, this file will be loaded by default without the need to use the --settings flag.')
group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
@ -61,7 +56,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft
# Image generation # Image generation
group = parser.add_argument_group('Image model') group = parser.add_argument_group('Image model')
group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).') group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).')
group.add_argument('--image-model-dir', type=str, default=str(user_data_dir / 'image_models'), help='Path to directory with all the image models.') group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.') group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.')
group.add_argument('--image-attn-backend', type=str, default=None, choices=['flash_attention_2', 'sdpa'], help='Attention backend for image model.') group.add_argument('--image-attn-backend', type=str, default=None, choices=['flash_attention_2', 'sdpa'], help='Attention backend for image model.')
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.') group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
@ -72,12 +67,12 @@ group.add_argument('--image-quant', type=str, default=None,
# Model loader # Model loader
group = parser.add_argument_group('Model loader') group = parser.add_argument_group('Model loader')
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT-LLM.') group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, TensorRT-LLM.')
# Cache # Cache
group = parser.add_argument_group('Context and cache') group = parser.add_argument_group('Context and cache')
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=0, metavar='N', help='Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.') group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, metavar='N', help='Context size in tokens.')
group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).') group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
# Speculative decoding # Speculative decoding
group = parser.add_argument_group('Speculative decoding') group = parser.add_argument_group('Speculative decoding')
@ -86,14 +81,10 @@ group.add_argument('--draft-max', type=int, default=4, help='Number of tokens to
group.add_argument('--gpu-layers-draft', type=int, default=256, help='Number of layers to offload to the GPU for the draft model.') group.add_argument('--gpu-layers-draft', type=int, default=256, help='Number of layers to offload to the GPU for the draft model.')
group.add_argument('--device-draft', type=str, default=None, help='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1') group.add_argument('--device-draft', type=str, default=None, help='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the prompt context for the draft model. If 0, uses the same as the main model.') group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
group.add_argument('--spec-type', type=str, default='none', choices=['none', 'ngram-mod', 'ngram-simple', 'ngram-map-k', 'ngram-map-k4v', 'ngram-cache'], help='Draftless speculative decoding type. Recommended: ngram-mod.')
group.add_argument('--spec-ngram-size-n', type=int, default=24, help='N-gram lookup size for ngram speculative decoding.')
group.add_argument('--spec-ngram-size-m', type=int, default=48, help='Draft n-gram size for ngram speculative decoding.')
group.add_argument('--spec-ngram-min-hits', type=int, default=1, help='Minimum n-gram hits for ngram-map speculative decoding.')
# llama.cpp # llama.cpp
group = parser.add_argument_group('llama.cpp') group = parser.add_argument_group('llama.cpp')
group.add_argument('--gpu-layers', '--n-gpu-layers', type=int, default=-1, metavar='N', help='Number of layers to offload to the GPU. -1 = auto.') group.add_argument('--gpu-layers', '--n-gpu-layers', type=int, default=256, metavar='N', help='Number of layers to offload to the GPU.')
group.add_argument('--cpu-moe', action='store_true', help='Move the experts to the CPU (for MoE models).') group.add_argument('--cpu-moe', action='store_true', help='Move the experts to the CPU (for MoE models).')
group.add_argument('--mmproj', type=str, default=None, help='Path to the mmproj file for vision models.') group.add_argument('--mmproj', type=str, default=None, help='Path to the mmproj file for vision models.')
group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.') group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
@ -107,8 +98,6 @@ group.add_argument('--ubatch-size', type=int, default=1024, help='Maximum number
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.') group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.') group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.') group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
group.add_argument('--fit-target', type=str, default='512', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.')
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"') group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
# Transformers/Accelerate # Transformers/Accelerate
@ -116,7 +105,7 @@ group = parser.add_argument_group('Transformers/Accelerate')
group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.') group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
group.add_argument('--cpu-memory', type=float, default=0, help='Maximum CPU memory in GiB. Use this for CPU offloading.') group.add_argument('--cpu-memory', type=float, default=0, help='Maximum CPU memory in GiB. Use this for CPU offloading.')
group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.') group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
group.add_argument('--disk-cache-dir', type=str, default=str(user_data_dir / 'cache'), help='Directory to save the disk cache to.') group.add_argument('--disk-cache-dir', type=str, default='user_data/cache', help='Directory to save the disk cache to. Defaults to "user_data/cache".')
group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).') group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).')
group.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') group.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
group.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.') group.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.')
@ -134,10 +123,34 @@ group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for
# ExLlamaV3 # ExLlamaV3
group = parser.add_argument_group('ExLlamaV3') group = parser.add_argument_group('ExLlamaV3')
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
group.add_argument('--enable-tp', '--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) to split the model across GPUs.') group.add_argument('--enable-tp', '--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) to split the model across GPUs.')
group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.') group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.')
group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
# ExLlamaV2
group = parser.add_argument_group('ExLlamaV2')
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.')
group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.')
group.add_argument('--num_experts_per_token', type=int, default=2, metavar='N', help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
# TensorRT-LLM
group = parser.add_argument_group('TensorRT-LLM')
group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn\'t support streaming yet.')
# DeepSpeed
group = parser.add_argument_group('DeepSpeed')
group.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
group.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
group.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
# RoPE
group = parser.add_argument_group('RoPE')
group.add_argument('--alpha_value', type=float, default=1, help='Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.')
group.add_argument('--rope_freq_base', type=int, default=0, help='If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).')
group.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.")
# Gradio # Gradio
group = parser.add_argument_group('Gradio') group = parser.add_argument_group('Gradio')
@ -157,7 +170,7 @@ group.add_argument('--portable', action='store_true', help='Hide features not av
# API # API
group = parser.add_argument_group('API') group = parser.add_argument_group('API')
group.add_argument('--api', action='store_true', help='Enable the API extension.') group.add_argument('--api', action='store_true', help='Enable the API extension.')
group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.') group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None) group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.') group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
group.add_argument('--api-key', type=str, default='', help='API authentication key.') group.add_argument('--api-key', type=str, default='', help='API authentication key.')
@ -166,53 +179,8 @@ group.add_argument('--api-enable-ipv6', action='store_true', help='Enable IPv6 f
group.add_argument('--api-disable-ipv4', action='store_true', help='Disable IPv4 for the API') group.add_argument('--api-disable-ipv4', action='store_true', help='Disable IPv4 for the API')
group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.') group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.')
# API generation defaults
_d = default_preset_values
group = parser.add_argument_group('API generation defaults')
group.add_argument('--temperature', type=float, default=_d['temperature'], metavar='N', help='Temperature')
group.add_argument('--dynatemp-low', type=float, default=_d['dynatemp_low'], metavar='N', help='Dynamic temperature low')
group.add_argument('--dynatemp-high', type=float, default=_d['dynatemp_high'], metavar='N', help='Dynamic temperature high')
group.add_argument('--dynatemp-exponent', type=float, default=_d['dynatemp_exponent'], metavar='N', help='Dynamic temperature exponent')
group.add_argument('--smoothing-factor', type=float, default=_d['smoothing_factor'], metavar='N', help='Smoothing factor')
group.add_argument('--smoothing-curve', type=float, default=_d['smoothing_curve'], metavar='N', help='Smoothing curve')
group.add_argument('--top-p', type=float, default=_d['top_p'], metavar='N', help='Top P')
group.add_argument('--top-k', type=int, default=_d['top_k'], metavar='N', help='Top K')
group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P')
group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma')
group.add_argument('--typical-p', type=float, default=_d['typical_p'], metavar='N', help='Typical P')
group.add_argument('--xtc-threshold', type=float, default=_d['xtc_threshold'], metavar='N', help='XTC threshold')
group.add_argument('--xtc-probability', type=float, default=_d['xtc_probability'], metavar='N', help='XTC probability')
group.add_argument('--epsilon-cutoff', type=float, default=_d['epsilon_cutoff'], metavar='N', help='Epsilon cutoff')
group.add_argument('--eta-cutoff', type=float, default=_d['eta_cutoff'], metavar='N', help='Eta cutoff')
group.add_argument('--tfs', type=float, default=_d['tfs'], metavar='N', help='TFS')
group.add_argument('--top-a', type=float, default=_d['top_a'], metavar='N', help='Top A')
group.add_argument('--adaptive-target', type=float, default=_d['adaptive_target'], metavar='N', help='Adaptive target')
group.add_argument('--adaptive-decay', type=float, default=_d['adaptive_decay'], metavar='N', help='Adaptive decay')
group.add_argument('--dry-multiplier', type=float, default=_d['dry_multiplier'], metavar='N', help='DRY multiplier')
group.add_argument('--dry-allowed-length', type=int, default=_d['dry_allowed_length'], metavar='N', help='DRY allowed length')
group.add_argument('--dry-base', type=float, default=_d['dry_base'], metavar='N', help='DRY base')
group.add_argument('--repetition-penalty', type=float, default=_d['repetition_penalty'], metavar='N', help='Repetition penalty')
group.add_argument('--frequency-penalty', type=float, default=_d['frequency_penalty'], metavar='N', help='Frequency penalty')
group.add_argument('--presence-penalty', type=float, default=_d['presence_penalty'], metavar='N', help='Presence penalty')
group.add_argument('--encoder-repetition-penalty', type=float, default=_d['encoder_repetition_penalty'], metavar='N', help='Encoder repetition penalty')
group.add_argument('--no-repeat-ngram-size', type=int, default=_d['no_repeat_ngram_size'], metavar='N', help='No repeat ngram size')
group.add_argument('--repetition-penalty-range', type=int, default=_d['repetition_penalty_range'], metavar='N', help='Repetition penalty range')
group.add_argument('--penalty-alpha', type=float, default=_d['penalty_alpha'], metavar='N', help='Penalty alpha')
group.add_argument('--guidance-scale', type=float, default=_d['guidance_scale'], metavar='N', help='Guidance scale')
group.add_argument('--mirostat-mode', type=int, default=_d['mirostat_mode'], metavar='N', help='Mirostat mode')
group.add_argument('--mirostat-tau', type=float, default=_d['mirostat_tau'], metavar='N', help='Mirostat tau')
group.add_argument('--mirostat-eta', type=float, default=_d['mirostat_eta'], metavar='N', help='Mirostat eta')
group.add_argument('--do-sample', action=argparse.BooleanOptionalAction, default=_d['do_sample'], help='Do sample')
group.add_argument('--dynamic-temperature', action=argparse.BooleanOptionalAction, default=_d['dynamic_temperature'], help='Dynamic temperature')
group.add_argument('--temperature-last', action=argparse.BooleanOptionalAction, default=_d['temperature_last'], help='Temperature last')
group.add_argument('--sampler-priority', type=str, default=_d['sampler_priority'], metavar='N', help='Sampler priority')
group.add_argument('--dry-sequence-breakers', type=str, default=_d['dry_sequence_breakers'], metavar='N', help='DRY sequence breakers')
group.add_argument('--enable-thinking', action=argparse.BooleanOptionalAction, default=True, help='Enable thinking')
group.add_argument('--reasoning-effort', type=str, default='medium', metavar='N', help='Reasoning effort')
group.add_argument('--chat-template-file', type=str, default=None, help='Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model\'s built-in template.')
# Handle CMD_FLAGS.txt # Handle CMD_FLAGS.txt
cmd_flags_path = user_data_dir / "CMD_FLAGS.txt" cmd_flags_path = Path(__file__).parent.parent / "user_data" / "CMD_FLAGS.txt"
if cmd_flags_path.exists(): if cmd_flags_path.exists():
with cmd_flags_path.open('r', encoding='utf-8') as f: with cmd_flags_path.open('r', encoding='utf-8') as f:
cmd_flags = ' '.join( cmd_flags = ' '.join(
@ -227,7 +195,6 @@ if cmd_flags_path.exists():
args = parser.parse_args() args = parser.parse_args()
user_data_dir = Path(args.user_data_dir) # Update from parsed args (may differ from pre-parse)
original_args = copy.deepcopy(args) original_args = copy.deepcopy(args)
args_defaults = parser.parse_args([]) args_defaults = parser.parse_args([])
@ -257,9 +224,8 @@ settings = {
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>', 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>',
'enable_web_search': False, 'enable_web_search': False,
'web_search_pages': 3, 'web_search_pages': 3,
'selected_tools': [],
'prompt-notebook': '', 'prompt-notebook': '',
'preset': 'Top-P' if (user_data_dir / 'presets/Top-P.yaml').exists() else None, 'preset': 'Qwen3 - Thinking' if Path('user_data/presets/Qwen3 - Thinking.yaml').exists() else None,
'max_new_tokens': 512, 'max_new_tokens': 512,
'max_new_tokens_min': 1, 'max_new_tokens_min': 1,
'max_new_tokens_max': 4096, 'max_new_tokens_max': 4096,
@ -284,7 +250,7 @@ settings = {
'include_past_attachments': True, 'include_past_attachments': True,
# Generation parameters - Curve shape # Generation parameters - Curve shape
'temperature': neutral_samplers['temperature'], 'temperature': 0.6,
'dynatemp_low': neutral_samplers['dynatemp_low'], 'dynatemp_low': neutral_samplers['dynatemp_low'],
'dynatemp_high': neutral_samplers['dynatemp_high'], 'dynatemp_high': neutral_samplers['dynatemp_high'],
'dynatemp_exponent': neutral_samplers['dynatemp_exponent'], 'dynatemp_exponent': neutral_samplers['dynatemp_exponent'],
@ -292,10 +258,9 @@ settings = {
'smoothing_curve': neutral_samplers['smoothing_curve'], 'smoothing_curve': neutral_samplers['smoothing_curve'],
# Generation parameters - Curve cutoff # Generation parameters - Curve cutoff
'top_p': 0.95,
'top_k': neutral_samplers['top_k'],
'min_p': neutral_samplers['min_p'], 'min_p': neutral_samplers['min_p'],
'top_n_sigma': neutral_samplers['top_n_sigma'], 'top_p': 0.95,
'top_k': 20,
'typical_p': neutral_samplers['typical_p'], 'typical_p': neutral_samplers['typical_p'],
'xtc_threshold': neutral_samplers['xtc_threshold'], 'xtc_threshold': neutral_samplers['xtc_threshold'],
'xtc_probability': neutral_samplers['xtc_probability'], 'xtc_probability': neutral_samplers['xtc_probability'],
@ -303,8 +268,7 @@ settings = {
'eta_cutoff': neutral_samplers['eta_cutoff'], 'eta_cutoff': neutral_samplers['eta_cutoff'],
'tfs': neutral_samplers['tfs'], 'tfs': neutral_samplers['tfs'],
'top_a': neutral_samplers['top_a'], 'top_a': neutral_samplers['top_a'],
'adaptive_target': neutral_samplers['adaptive_target'], 'top_n_sigma': neutral_samplers['top_n_sigma'],
'adaptive_decay': neutral_samplers['adaptive_decay'],
# Generation parameters - Repetition suppression # Generation parameters - Repetition suppression
'dry_multiplier': neutral_samplers['dry_multiplier'], 'dry_multiplier': neutral_samplers['dry_multiplier'],
@ -334,7 +298,6 @@ settings = {
# Character settings # Character settings
'character': 'Assistant', 'character': 'Assistant',
'user': 'Default',
'name1': 'You', 'name1': 'You',
'name2': 'AI', 'name2': 'AI',
'user_bio': '', 'user_bio': '',
@ -342,7 +305,7 @@ settings = {
'greeting': 'How can I help you today?', 'greeting': 'How can I help you today?',
'custom_system_message': '', 'custom_system_message': '',
'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}", 'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}",
'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {{- '[Tool result: ' + message['content'] + ']\\n' -}}\n {%- elif message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- elif message['tool_calls'] is defined and message['tool_calls'] -%}\n {%- for tc in message['tool_calls'] -%}\n {{- '[Calling: ' + tc['function']['name'] + '(' + tc['function']['arguments'] + ')]\\n' -}}\n {%- endfor -%}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}", 'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}",
# Extensions # Extensions
'default_extensions': [], 'default_extensions': [],
@ -372,11 +335,6 @@ default_settings = copy.deepcopy(settings)
def do_cmd_flags_warnings(): def do_cmd_flags_warnings():
# Validate --chat-template-file
if args.chat_template_file and not Path(args.chat_template_file).is_file():
logger.error(f"--chat-template-file: file not found: {args.chat_template_file}")
sys.exit(1)
# Security warnings # Security warnings
if args.trust_remote_code: if args.trust_remote_code:
logger.warning( logger.warning(
@ -390,16 +348,9 @@ def do_cmd_flags_warnings():
if args.share: if args.share:
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.") logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)): if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
logger.warning("You are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.") logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
if args.multi_user: if args.multi_user:
logger.warning( logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
'Multi-user mode is enabled. Known limitations:'
'\n- The Stop button stops generation for all users, not just you.'
'\n- Chat history is not saved and will be lost on page refresh.'
'\n- Only one user can generate at a time unless using a parallel-capable backend (e.g. llama.cpp with --parallel N for N > 1, or ExLlamaV3).'
'\n\nThis mode works best for small trusted teams.'
'\n\nDo not expose publicly. Grayed-out actions can easily be bypassed client-side.\n'
)
def apply_image_model_cli_overrides(): def apply_image_model_cli_overrides():
@ -427,6 +378,10 @@ def fix_loader_name(name):
return 'llama.cpp' return 'llama.cpp'
elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']: elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']:
return 'Transformers' return 'Transformers'
elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']:
return 'ExLlamav2'
elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']:
return 'ExLlamav2_HF'
elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']: elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']:
return 'ExLlamav3_HF' return 'ExLlamav3_HF'
elif name in ['exllamav3']: elif name in ['exllamav3']:

View file

@ -1,10 +1,15 @@
from pathlib import Path from pathlib import Path
from tensorrt_llm._tensorrt_engine import LLM import tensorrt_llm
from tensorrt_llm.llmapi import SamplingParams import torch
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
from modules import shared from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.text_generation import (
get_max_prompt_length,
get_reply_from_output_ids
)
class TensorRTLLMModel: class TensorRTLLMModel:
@ -12,50 +17,110 @@ class TensorRTLLMModel:
pass pass
@classmethod @classmethod
def from_pretrained(cls, path_to_model): def from_pretrained(self, path_to_model):
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
llm = LLM( path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
model=str(path_to_model), runtime_rank = tensorrt_llm.mpi_rank()
skip_tokenizer_init=False,
# Define model settings
runner_kwargs = dict(
engine_dir=str(path_to_model),
lora_dir=None,
rank=runtime_rank,
debug_mode=False,
lora_ckpt_source="hf",
) )
result = cls() if shared.args.cpp_runner:
result.llm = llm logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
result.tokenizer = llm.tokenizer runner_kwargs.update(
max_batch_size=1,
max_input_len=shared.args.ctx_size - 512,
max_output_len=512,
max_beam_width=1,
max_attention_window_size=None,
sink_token_length=None,
)
else:
logger.info("TensorRT-LLM: Using \"ModelRunner\"")
# Load the model
runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
runner = runner_cls.from_dir(**runner_kwargs)
result = self()
result.model = runner
result.runtime_rank = runtime_rank
return result return result
def generate_with_streaming(self, prompt, state): def generate_with_streaming(self, prompt, state):
sampling_params = SamplingParams( batch_input_ids = []
max_tokens=state['max_new_tokens'] if not state['auto_max_new_tokens'] input_ids = shared.tokenizer.encode(
else state['truncation_length'] - len(shared.tokenizer.encode(prompt)), prompt,
end_id=shared.tokenizer.eos_token_id, add_special_tokens=True,
truncation=False,
)
input_ids = torch.tensor(input_ids, dtype=torch.int32)
input_ids = input_ids[-get_max_prompt_length(state):] # Apply truncation_length
batch_input_ids.append(input_ids)
if shared.args.cpp_runner:
max_new_tokens = min(512, state['max_new_tokens'])
elif state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
else:
max_new_tokens = state['max_new_tokens']
with torch.no_grad():
generator = self.model.generate(
batch_input_ids,
max_new_tokens=max_new_tokens,
max_attention_window_size=None,
sink_token_length=None,
end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
temperature=state['temperature'], temperature=state['temperature'],
top_k=state['top_k'], top_k=state['top_k'],
top_p=state['top_p'], top_p=state['top_p'],
min_p=state['min_p'], num_beams=1,
length_penalty=1.0,
repetition_penalty=state['repetition_penalty'], repetition_penalty=state['repetition_penalty'],
presence_penalty=state['presence_penalty'], presence_penalty=state['presence_penalty'],
frequency_penalty=state['frequency_penalty'], frequency_penalty=state['frequency_penalty'],
no_repeat_ngram_size=state['no_repeat_ngram_size'] if state['no_repeat_ngram_size'] > 0 else None, stop_words_list=None,
seed=state['seed'], bad_words_list=None,
ignore_eos=state['ban_eos_token'], lora_uids=None,
add_special_tokens=state['add_bos_token'], prompt_table_path=None,
skip_special_tokens=state['skip_special_tokens'], prompt_tasks=None,
streaming=not shared.args.cpp_runner,
output_sequence_lengths=True,
return_dict=True,
medusa_choices=None
) )
stop_event = state.get('stop_event') torch.cuda.synchronize()
result = self.llm.generate_async(prompt, sampling_params=sampling_params, streaming=True)
cumulative_reply = '' cumulative_reply = ''
for output in result: starting_from = batch_input_ids[0].shape[-1]
if shared.stop_everything or (stop_event and stop_event.is_set()):
result.abort() if shared.args.cpp_runner:
sequence_length = generator['sequence_lengths'][0].item()
output_ids = generator['output_ids'][0][0][:sequence_length].tolist()
cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
starting_from = sequence_length
yield cumulative_reply
else:
for curr_outputs in generator:
if shared.stop_everything:
break break
text_diff = output.outputs[0].text_diff sequence_length = curr_outputs['sequence_lengths'][0].item()
if text_diff: output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()
cumulative_reply += text_diff
cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
starting_from = sequence_length
yield cumulative_reply yield cumulative_reply
def generate(self, prompt, state): def generate(self, prompt, state):
@ -64,8 +129,3 @@ class TensorRTLLMModel:
pass pass
return output return output
def unload(self):
if hasattr(self, 'llm') and self.llm is not None:
self.llm.shutdown()
self.llm = None

View file

@ -22,22 +22,12 @@ def generate_reply(*args, **kwargs):
from modules.models import load_model from modules.models import load_model
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
state = args[1] if len(args) > 1 else kwargs.get('state', {})
use_parallel = (
state.get('stop_event') is not None
and shared.model.__class__.__name__ in ['Exllamav3Model', 'LlamaServer', 'TensorRTLLMModel']
and (shared.model.__class__.__name__ != 'LlamaServer' or shared.args.parallel > 1)
)
if not use_parallel:
shared.generation_lock.acquire() shared.generation_lock.acquire()
try: try:
for result in _generate_reply(*args, **kwargs): for result in _generate_reply(*args, **kwargs):
yield result yield result
finally: finally:
models.last_generation_time = time.time() models.last_generation_time = time.time()
if not use_parallel:
shared.generation_lock.release() shared.generation_lock.release()
@ -50,7 +40,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
yield '' yield ''
return return
if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav3Model', 'TensorRTLLMModel']: if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
generate_func = generate_reply_custom generate_func = generate_reply_custom
else: else:
generate_func = generate_reply_HF generate_func = generate_reply_HF
@ -78,13 +68,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
reply = '' reply = ''
is_stream = state['stream'] is_stream = state['stream']
if len(all_stop_strings) > 0 and not state['stream']: if len(all_stop_strings) > 0 and not state['stream']:
original_logits_processor = state.get('logits_processor')
stop_event_ref = state.pop('stop_event', None)
state = copy.deepcopy(state) state = copy.deepcopy(state)
if stop_event_ref is not None:
state['stop_event'] = stop_event_ref
if original_logits_processor is not None:
state['logits_processor'] = original_logits_processor
state['stream'] = True state['stream'] = True
# Generate # Generate
@ -115,8 +99,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
yield reply yield reply
last_update = time.monotonic() last_update = time.monotonic()
stop_event = state.get('stop_event') if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
if stop_found or shared.stop_everything or (stop_event and stop_event.is_set()):
break break
if not is_chat: if not is_chat:
@ -145,9 +128,9 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
from modules.torch_utils import get_device from modules.torch_utils import get_device
if shared.model.__class__.__name__ in ['Exllamav3Model', 'TensorRTLLMModel']: if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
input_ids = shared.tokenizer.encode(str(prompt)) input_ids = shared.tokenizer.encode(str(prompt))
if shared.model.__class__.__name__ not in ['Exllamav3Model']: if shared.model.__class__.__name__ not in ['Exllamav2Model', 'Exllamav3Model']:
input_ids = np.array(input_ids).reshape(1, len(input_ids)) input_ids = np.array(input_ids).reshape(1, len(input_ids))
else: else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens) input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
@ -165,7 +148,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if truncation_length is not None: if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:] input_ids = input_ids[:, -truncation_length:]
if shared.model.__class__.__name__ in ['Exllamav3Model', 'TensorRTLLMModel'] or shared.args.cpu: if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel'] or shared.args.cpu:
return input_ids return input_ids
else: else:
device = get_device() device = get_device()
@ -334,8 +317,6 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
'tfs', 'tfs',
'top_a', 'top_a',
'top_n_sigma', 'top_n_sigma',
'adaptive_target',
'adaptive_decay',
'dry_multiplier', 'dry_multiplier',
'dry_allowed_length', 'dry_allowed_length',
'dry_base', 'dry_base',
@ -378,7 +359,7 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()] generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
if state['custom_token_bans']: if state['custom_token_bans']:
to_ban = [int(x.strip()) for x in state['custom_token_bans'].split(',') if x.strip()] to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
if len(to_ban) > 0: if len(to_ban) > 0:
if generate_params.get('suppress_tokens', None): if generate_params.get('suppress_tokens', None):
generate_params['suppress_tokens'] += to_ban generate_params['suppress_tokens'] += to_ban
@ -389,6 +370,8 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
generate_params['negative_prompt_ids'] = encode(state['negative_prompt']) generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
generate_params.update({'use_cache': not shared.args.no_cache}) generate_params.update({'use_cache': not shared.args.no_cache})
if shared.args.deepspeed:
generate_params.update({'synced_gpus': True})
# Encode the input # Encode the input
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state)) input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
@ -491,10 +474,7 @@ def generate_reply_custom(question, original_question, state, stopping_strings=N
For models that do not use the transformers library for sampling For models that do not use the transformers library for sampling
""" """
stop_event_ref = state.pop('stop_event', None)
state = copy.deepcopy(state) state = copy.deepcopy(state)
if stop_event_ref is not None:
state['stop_event'] = stop_event_ref
state['seed'] = set_manual_seed(state['seed']) state['seed'] = set_manual_seed(state['seed'])
t0 = time.time() t0 = time.time()
reply = '' reply = ''

View file

@ -1,667 +0,0 @@
import json
import random
import re
def get_tool_call_id() -> str:
    """Return a pseudo-random tool call identifier.

    The identifier is the literal prefix ``call_`` followed by eight
    characters, each picked independently with ``random.choice`` from
    the lowercase ASCII letters and the digits.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
    suffix = ""
    for _ in range(8):
        suffix += random.choice(alphabet)

    return f"call_{suffix.lower()}"
# All known opening markers for tool calls across model formats.
# streaming_tool_buffer_check() treats the presence of any of these in
# the accumulated text as a reason to withhold output, and
# detect_tool_call_format() returns this whole list when the template
# is empty or matches no known format.
TOOL_CALL_OPENING_MARKERS = [
'<tool_call>',
'<function_call>',
'<minimax:tool_call>',
'<|tool_call_begin|>',
'<|tool_calls_section_begin|>',
'<tool▁call▁begin>',
'<tool▁calls▁begin>',
'[TOOL_CALLS]',
'to=functions.',
'<|channel|>commentary',
]
def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False):
    '''
    Decide whether streamed output should be held back because it may
    be (the start of) tool-call markup.

    Args:
        text: Full accumulated internal text.
        markers: Template-specific markers used for partial-prefix
            matching at the end of the text. If None, all of
            TOOL_CALL_OPENING_MARKERS are used.
        tool_names: List of tool function names.
        check_bare_names: Whether to also match tool names (whole or
            partial) at the end of the text, for models whose template
            format is unknown.

    Returns:
        True if the text contains a known marker, contains a tool name
        directly followed by "{" or " {", or ends with a proper prefix
        of a marker (or, when check_bare_names is set, with a tool name
        or a proper prefix of one). False otherwise.
    '''
    def ends_with_proper_prefix(token):
        # Longest candidate first; the full token itself is excluded.
        longest = min(len(token) - 1, len(text))
        return any(text.endswith(token[:size]) for size in range(longest, 0, -1))

    # A complete marker anywhere in the text: always checked against
    # every known marker, regardless of the template-specific list.
    if any(marker in text for marker in TOOL_CALL_OPENING_MARKERS):
        return True

    # Bare function-name call: "get_weather{...}" or "get_weather {...}".
    if tool_names and any(name + '{' in text or name + ' {' in text for name in tool_names):
        return True

    # The text may end in the middle of a marker that is still streaming in.
    partial_candidates = TOOL_CALL_OPENING_MARKERS if markers is None else markers
    if any(ends_with_proper_prefix(marker) for marker in partial_candidates):
        return True

    # Same idea for bare tool names, only when the format is unknown.
    if check_bare_names and tool_names:
        return any(text.endswith(name) or ends_with_proper_prefix(name) for name in tool_names)

    return False
def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]):
    """Normalize a decoded JSON object into an OpenAI-style tool call.

    Accepted input shapes:
      * ``{"name": ..., "arguments"/"parameters": ...}`` (no wrapper)
      * ``{"function": "<name>", ...}`` (function given as a string)
      * ``{"function": {"name": ..., ...}}`` (already wrapped)

    Args:
        candidate_dict: Value decoded from model output. Anything that is
            not a dict is rejected. A dict may be mutated in place.
        tool_names: Names of the tools that may be called.

    Returns:
        A dict of the form ``{"type": "function", "function": {...}}`` whose
        function name is in ``tool_names``, with a legacy ``parameters``
        key renamed to ``arguments``; or None if the candidate does not
        describe a call to a known tool.
    """
    # Decoded JSON can contain numbers, strings, lists or null. Membership
    # tests and item access on those either raise TypeError or do substring
    # matching, so they are rejected up front.
    if not isinstance(candidate_dict, dict):
        return None

    # Bare {"name": ...} object: wrap it as the "function" payload.
    if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
        candidate_dict = {"type": "function", "function": candidate_dict}

    # "function" given as a plain string: treat it as the name, then wrap.
    if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
        candidate_dict['name'] = candidate_dict['function']
        del candidate_dict['function']
        candidate_dict = {"type": "function", "function": candidate_dict}

    if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
        function = candidate_dict['function']
        # Only calls to known tools are accepted.
        if 'name' in function and function['name'] in tool_names:
            # Ensure the required property 'type' exists with the right value.
            candidate_dict["type"] = "function"
            # Map 'parameters' (used by some older models) to 'arguments'.
            if "arguments" not in function and "parameters" in function:
                function["arguments"] = function["parameters"]
                del function["parameters"]

            return candidate_dict

    return None
def _extract_balanced_json(text: str, start: int) -> str | None:
"""Extract a balanced JSON object from text starting at the given position.
Walks through the string tracking brace depth and string boundaries
to correctly handle arbitrary nesting levels.
"""
if start >= len(text) or text[start] != '{':
return None
depth = 0
in_string = False
escape_next = False
for i in range(start, len(text)):
c = text[i]
if escape_next:
escape_next = False
continue
if c == '\\' and in_string:
escape_next = True
continue
if c == '"':
in_string = not in_string
continue
if in_string:
continue
if c == '{':
depth += 1
elif c == '}':
depth -= 1
if depth == 0:
return text[start:i + 1]
return None
def _parse_channel_tool_calls(answer: str, tool_names: list[str]):
"""Parse channel-based tool calls used by GPT-OSS and similar models.
Format:
<|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"}
or:
<|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
"""
matches = []
start_pos = None
# Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
# Pattern 2: to=functions.NAME after <|channel|> (alternative format)
patterns = [
r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
]
for pattern in patterns:
for m in re.finditer(pattern, answer):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extract_balanced_json(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
if start_pos is None:
prefix = answer.rfind('<|start|>assistant', 0, m.start())
start_pos = prefix if prefix != -1 else m.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
if matches:
break
return matches, start_pos
def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]):
"""Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.
Format:
[TOOL_CALLS]func_name[ARGS]{"arg": "value"}
"""
matches = []
start_pos = None
for m in re.finditer(
r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extract_balanced_json(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
if start_pos is None:
start_pos = m.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches, start_pos
def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]):
"""Parse bare function-name style tool calls used by Mistral and similar models.
Format:
functionName{"arg": "value"}
Multiple calls are concatenated directly or separated by whitespace.
"""
matches = []
start_pos = None
# Match tool name followed by opening brace, then extract balanced JSON
escaped_names = [re.escape(name) for name in tool_names]
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
for match in re.finditer(pattern, answer):
text = match.group(0)
name = None
for n in tool_names:
if text.startswith(n):
name = n
break
if not name:
continue
brace_start = match.end() - 1
json_str = _extract_balanced_json(answer, brace_start)
if json_str is None:
continue
try:
arguments = json.loads(json_str)
if start_pos is None:
start_pos = match.start()
matches.append({
"type": "function",
"function": {
"name": name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches, start_pos
def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]):
"""Parse XML-parameter style tool calls used by Qwen3.5 and similar models.
Format:
<tool_call>
<function=function_name>
<parameter=param_name>value</parameter>
</function>
</tool_call>
"""
matches = []
start_pos = None
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
func_match = re.search(r'<function=([^>]+)>', tc_content)
if not func_match:
continue
func_name = func_match.group(1).strip()
if func_name not in tool_names:
continue
arguments = {}
for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
param_name = param_match.group(1).strip()
param_value = param_match.group(2).strip()
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass # keep as string
arguments[param_name] = param_value
if start_pos is None:
start_pos = tc_match.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches, start_pos
def _parse_kimi_tool_calls(answer: str, tool_names: list[str]):
"""Parse Kimi-K2-style tool calls using pipe-delimited tokens.
Format:
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
<|tool_calls_section_end|>
"""
matches = []
start_pos = None
for m in re.finditer(
r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extract_balanced_json(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
if start_pos is None:
# Check for section begin marker before the call marker
section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start())
start_pos = section if section != -1 else m.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches, start_pos
def _parse_minimax_tool_calls(answer: str, tool_names: list[str]):
"""Parse MiniMax-style tool calls using invoke/parameter XML tags.
Format:
<minimax:tool_call>
<invoke name="function_name">
<parameter name="param_name">value</parameter>
</invoke>
</minimax:tool_call>
"""
matches = []
start_pos = None
for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
# Split on <invoke> to handle multiple parallel calls in one block
for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
func_name = invoke_match.group(1).strip()
if func_name not in tool_names:
continue
invoke_body = invoke_match.group(2)
arguments = {}
for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
param_name = param_match.group(1).strip()
param_value = param_match.group(2).strip()
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass # keep as string
arguments[param_name] = param_value
if start_pos is None:
start_pos = tc_match.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches, start_pos
def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]):
"""Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
Format:
<toolcallsbegin><toolcallbegin>func_name<toolsep>{"arg": "value"}<toolcallend><toolcallsend>
"""
matches = []
start_pos = None
for m in re.finditer(
r'<tool▁call▁begin>\s*(\S+?)\s*<tool▁sep>\s*',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extract_balanced_json(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
if start_pos is None:
# Check for section begin marker before the call marker
section = answer.rfind('<tool▁calls▁begin>', 0, m.start())
start_pos = section if section != -1 else m.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches, start_pos
def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
"""Parse GLM-style tool calls using arg_key/arg_value XML pairs.
Format:
<tool_call>function_name
<arg_key>key1</arg_key>
<arg_value>value1</arg_value>
</tool_call>
"""
matches = []
start_pos = None
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
# First non-tag text is the function name
name_match = re.match(r'([^<\s]+)', tc_content.strip())
if not name_match:
continue
func_name = name_match.group(1).strip()
if func_name not in tool_names:
continue
# Extract arg_key/arg_value pairs
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
if len(keys) != len(vals):
continue
arguments = {}
for k, v in zip(keys, vals):
try:
v = json.loads(v)
except (json.JSONDecodeError, ValueError):
pass # keep as string
arguments[k] = v
if start_pos is None:
start_pos = tc_match.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches, start_pos
def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
Format:
[func_name(param1="value1", param2="value2"), func_name2(...)]
"""
matches = []
start_pos = None
# Match a bracketed list of function calls
bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
if not bracket_match:
return matches, start_pos
inner = bracket_match.group(1)
# Build pattern for known tool names
escaped_names = [re.escape(name) for name in tool_names]
name_pattern = '|'.join(escaped_names)
for call_match in re.finditer(
r'(' + name_pattern + r')\(([^)]*)\)',
inner
):
func_name = call_match.group(1)
params_str = call_match.group(2).strip()
arguments = {}
if params_str:
# Parse key="value" pairs, handling commas inside quoted values
for param_match in re.finditer(
r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
params_str
):
param_name = param_match.group(1)
param_value = param_match.group(2).strip()
# Strip surrounding quotes
if (param_value.startswith('"') and param_value.endswith('"')) or \
(param_value.startswith("'") and param_value.endswith("'")):
param_value = param_value[1:-1]
# Try to parse as JSON for numeric/bool/null values
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass
arguments[param_name] = param_value
if start_pos is None:
start_pos = bracket_match.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches, start_pos
# Format registry: maps template substrings to the parser and streaming
# markers for that format. When a format's hints are NOT found in the
# template, its parser and markers are excluded.
#
# Each entry has three keys, all read by detect_tool_call_format():
#   'template_hints': substrings searched for in the template; the
#                     format counts as matched if any one is present.
#   'parser':         the parser function for the format, or None when
#                     the entry only contributes markers.
#   'markers':        opening markers associated with the format.
TOOL_CALL_FORMATS = [
{
'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'],
'parser': _parse_deep_seek_tool_calls,
'markers': ['<tool▁call▁begin>', '<tool▁calls▁begin>'],
},
{
'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'],
'parser': _parse_kimi_tool_calls,
'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'],
},
{
'template_hints': ['to=functions.', '<|channel|>'],
'parser': _parse_channel_tool_calls,
'markers': ['to=functions.', '<|channel|>commentary'],
},
{
'template_hints': ['minimax:tool_call'],
'parser': _parse_minimax_tool_calls,
'markers': ['<minimax:tool_call>'],
},
# The GLM and XML-parameter formats share the '<tool_call>' marker but
# are told apart by their template hints.
{
'template_hints': ['<arg_key>'],
'parser': _parse_glm_tool_calls,
'markers': ['<tool_call>'],
},
{
'template_hints': ['<tool_call>'],
'parser': _parse_xml_param_tool_calls,
'markers': ['<tool_call>'],
},
{
'template_hints': ['[TOOL_CALLS]'],
'parser': _parse_mistral_token_tool_calls,
'markers': ['[TOOL_CALLS]'],
},
# No dedicated parser: this entry only affects which markers are kept.
{
'template_hints': ['<function_call>'],
'parser': None,
'markers': ['<function_call>'],
},
]
# Default ordered list of all specialized parsers.
# parse_tool_call() tries them in this order and returns the result of
# the first parser that yields at least one match, so order matters.
# detect_tool_call_format() filters this list down per template.
ALL_PARSERS = [
_parse_deep_seek_tool_calls,
_parse_kimi_tool_calls,
_parse_channel_tool_calls,
_parse_minimax_tool_calls,
_parse_glm_tool_calls,
_parse_xml_param_tool_calls,
_parse_mistral_token_tool_calls,
_parse_bare_name_tool_calls,
_parse_pythonic_tool_calls,
]
def detect_tool_call_format(template_str):
    """Work out which tool call formats a chat/instruction template uses.

    The selection is exclusion-based: every parser and marker is kept
    unless it belongs to a TOOL_CALL_FORMATS entry none of whose hints
    occur in the template. A marker shared with a matched format is kept
    even if another, unmatched format also lists it.

    Returns:
        (parsers, streaming_markers, check_bare_names). When the template
        is empty or matches no known format this is
        ``(None, TOOL_CALL_OPENING_MARKERS, True)``; otherwise the
        filtered parser list, the filtered marker list, and False.
    """
    if not template_str:
        return None, TOOL_CALL_OPENING_MARKERS, True

    recognized = False
    kept_markers = []
    dropped_markers = []
    dropped_parsers = []

    for fmt in TOOL_CALL_FORMATS:
        hinted = False
        for hint in fmt['template_hints']:
            if hint in template_str:
                hinted = True
                break

        if hinted:
            recognized = True
            kept_markers.extend(fmt['markers'])
            continue

        # Format not used by this template: schedule it for exclusion.
        if fmt['parser'] is not None:
            dropped_parsers.append(fmt['parser'])

        dropped_markers.extend(fmt['markers'])

    if not recognized:
        return None, TOOL_CALL_OPENING_MARKERS, True

    parsers = []
    for parser in ALL_PARSERS:
        if parser not in dropped_parsers:
            parsers.append(parser)

    markers = []
    for marker in TOOL_CALL_OPENING_MARKERS:
        if marker in kept_markers or marker not in dropped_markers:
            markers.append(marker)

    return parsers, markers, False
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None):
    """Extract tool calls from a model answer.

    Strategy, in order:
      1. Run each specialized parser and return the result of the first
         one that finds a call.
      2. Look for JSON wrapped in a fenced code block or in matching
         XML-like tags, and validate each decoded object.
      3. If still nothing was found, try to decode the whole answer as JSON.

    Args:
        answer: Raw model output.
        tool_names: Names of the tools that may be called.
        return_prefix: When true, also return the text preceding the
            first tool call.
        parsers: Specialized parsers to try; defaults to ALL_PARSERS.

    Returns:
        The list of tool-call dicts, or ``(matches, prefix)`` when
        ``return_prefix`` is set. ``prefix`` is ``answer[:start_pos]``
        when calls were found and a start offset is known, else ``''``
        (step 3 does not record an offset).
    """
    matches = []
    start_pos = None

    def _return(matches, start_pos):
        if return_prefix:
            prefix = answer[:start_pos] if matches and start_pos is not None else ''
            return matches, prefix

        return matches

    # Try specialized parsers.
    for parser in (parsers if parsers is not None else ALL_PARSERS):
        matches, start_pos = parser(answer, tool_names)
        if matches:
            return _return(matches, start_pos)

    # Generic fallback: regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]

    for pattern in patterns:
        for match in re.finditer(pattern, answer, re.DOTALL):
            if match.group(2) is None:
                continue

            # remove backtick wraps if present
            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
            candidate = re.sub(r"```$", "", candidate.strip())
            # unwrap inner tags
            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)

            if not candidate.strip().startswith("["):
                candidate = "[" + candidate + "]"

            candidates = []
            try:
                # parse the candidate JSON into a dictionary
                candidates = json.loads(candidate)
                if not isinstance(candidates, list):
                    candidates = [candidates]
            except json.JSONDecodeError:
                # Ignore invalid JSON silently
                continue

            for candidate_dict in candidates:
                # Valid JSON is not necessarily an object: numbers, strings
                # and nested lists cannot be tool calls and would make the
                # sanitizer raise TypeError, so skip them here.
                if not isinstance(candidate_dict, dict):
                    continue

                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    if start_pos is None:
                        start_pos = match.start()

                    matches.append(checked_candidate)

    # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
    if len(matches) == 0:
        try:
            candidate = answer
            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)

            if not candidate.strip().startswith("["):
                candidate = "[" + candidate + "]"

            # parse the candidate JSON into a dictionary
            candidates = json.loads(candidate)
            if not isinstance(candidates, list):
                candidates = [candidates]

            for candidate_dict in candidates:
                # Same guard as above: an answer such as "[1, 2]" is valid
                # JSON but contains no objects.
                if not isinstance(candidate_dict, dict):
                    continue

                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    matches.append(checked_candidate)
        except json.JSONDecodeError:
            # Ignore invalid JSON silently
            pass

    return _return(matches, start_pos)

View file

@ -1,71 +0,0 @@
import importlib.util
import json
from modules import shared
from modules.logging_colors import logger
from modules.utils import natural_keys, sanitize_filename
def get_available_tools():
    """List the tool scripts found in the ``tools`` directory under ``shared.user_data_dir``.

    The directory is created if it does not exist yet. Returns the
    ``*.py`` file names without their extension, sorted with
    ``natural_keys`` as the sort key.
    """
    tools_dir = shared.user_data_dir / 'tools'
    tools_dir.mkdir(parents=True, exist_ok=True)

    script_names = [script.stem for script in tools_dir.glob('*.py')]
    return sorted(script_names, key=natural_keys)
def load_tools(selected_names):
    """
    Import selected tool scripts and return their definitions and executors.

    Each name is passed through sanitize_filename() and resolved to
    ``<user_data_dir>/tools/<name>.py``. The script is imported, which
    runs its top-level code, and must expose a module-level ``tool``
    dict and a callable ``execute``. A script is skipped (with a log
    entry where noted) when:
      - its sanitized name is empty or the file does not exist (silent),
      - importing it raises (logged with traceback),
      - ``tool`` is not a dict or ``execute`` is not callable (warning),
      - its function name was already registered by an earlier script (warning).

    The function name is ``tool['function']['name']`` when available and
    falls back to the script name.

    Returns (tool_defs, executors) where:
    - tool_defs: list of OpenAI-format tool dicts
    - executors: dict mapping function_name -> execute callable
    """
    tool_defs = []
    executors = {}

    for name in selected_names:
        name = sanitize_filename(name)
        if not name:
            continue

        path = shared.user_data_dir / 'tools' / f'{name}.py'
        if not path.exists():
            continue

        try:
            spec = importlib.util.spec_from_file_location(f"tool_{name}", str(path))
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        except Exception:
            logger.exception(f'Failed to load tool script "{name}"')
            continue

        tool_def = getattr(module, 'tool', None)
        execute_fn = getattr(module, 'execute', None)
        # Validate the types, not just presence: a malformed script must be
        # skipped like any other bad script instead of raising
        # AttributeError below and aborting the load of every other tool.
        if not isinstance(tool_def, dict) or not callable(execute_fn):
            logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.')
            continue

        function_spec = tool_def.get('function')
        if isinstance(function_spec, dict):
            func_name = function_spec.get('name', name)
        else:
            # No usable "function" section: fall back to the script name.
            func_name = name

        if func_name in executors:
            logger.warning(f'Tool "{name}" declares function name "{func_name}" which conflicts with an already loaded tool. Skipping.')
            continue

        tool_defs.append(tool_def)
        executors[func_name] = execute_fn

    return tool_defs, executors
def execute_tool(func_name, arguments, executors):
    """Run the executor registered for ``func_name`` and return its result as a string.

    ``arguments`` given as a string is decoded as JSON before the call;
    any other value is passed through unchanged. A string result is
    returned as-is, anything else is serialized with ``json.dumps``.

    Failures never propagate: an unknown function name, or any exception
    raised while decoding, executing or serializing, yields a JSON object
    string with a single ``"error"`` key (exceptions are also logged).
    """
    executor = executors.get(func_name)
    if executor is None:
        return json.dumps({"error": f"Unknown tool: {func_name}"})

    try:
        call_args = json.loads(arguments) if isinstance(arguments, str) else arguments
        outcome = executor(call_args)
        if isinstance(outcome, str):
            return outcome

        return json.dumps(outcome)
    except Exception as e:
        logger.exception(f'Tool "{func_name}" execution failed')
        return json.dumps({"error": str(e)})

View file

@ -12,6 +12,9 @@ def get_device():
return shared.model.device return shared.model.device
elif torch.cuda.is_available(): elif torch.cuda.is_available():
return torch.device('cuda') return torch.device('cuda')
elif shared.args.deepspeed:
import deepspeed
return deepspeed.get_accelerator().current_device_name()
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
return torch.device('mps') return torch.device('mps')
elif is_torch_xpu_available(): elif is_torch_xpu_available():

View file

@ -14,7 +14,6 @@ import traceback
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
import yaml
import gradio as gr import gradio as gr
from modules import shared, ui, utils from modules import shared, ui, utils
@ -25,8 +24,9 @@ from modules.evaluate import (
) )
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.models import reload_model from modules.models import reload_model
from modules.utils import natural_keys
PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to"] PARAMETERS = ["lora_name", "always_override", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
WANT_INTERRUPT = False WANT_INTERRUPT = False
train_log = {} train_log = {}
@ -53,8 +53,7 @@ def create_ui():
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background']) always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
with gr.Accordion(label='Target Modules', open=False, elem_classes='tgw-accordion'): with gr.Accordion(label='Target Modules', open=False, elem_classes='tgw-accordion'):
gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM and adapter size.") gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM requirements and adapter size.\nNOTE: Only works for model_id='llama', other types will retain default training behavior and not use these settings.")
all_linear = gr.Checkbox(label='Target all linear layers', value=True, info='Targets every nn.Linear layer except lm_head. Works for any model architecture. When checked, the individual module checkboxes below are ignored.', elem_classes=['no-background'])
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
q_proj_en = gr.Checkbox(label='Enable q_proj', value=True) q_proj_en = gr.Checkbox(label='Enable q_proj', value=True)
@ -73,60 +72,67 @@ def create_ui():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
lora_rank = gr.Slider(label='LoRA Rank', value=8, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.') lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
lora_alpha = gr.Slider(label='LoRA Alpha', value=16, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.') lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
batch_size = gr.Slider(label='Batch Size', value=32, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.') batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.') micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=512, step=32, info='Maximum sequence length in tokens. For instruction datasets, conversations longer than this are dropped. For text datasets, documents are split into chunks of this size. Higher values require more VRAM.') cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
with gr.Column(): with gr.Column():
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a full training checkpoint (adapter weights, optimizer, scheduler) will be saved every time this many steps pass. Training can be resumed from these checkpoints.') save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.') epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.') learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
with gr.Row(): with gr.Row():
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='cosine', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown']) lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])
with gr.Accordion(label='Advanced Options', open=False, elem_classes='tgw-accordion'): with gr.Accordion(label='Advanced Options', open=False, elem_classes='tgw-accordion'):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.0, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.') lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)') stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
with gr.Row(): with gr.Row():
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Optimizer algorithm. adamw_torch is the standard choice. adamw_bnb_8bit uses less VRAM. adafactor is memory-efficient for large models.', elem_classes=['slim-dropdown']) optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.', elem_classes=['slim-dropdown'])
with gr.Column(): with gr.Column():
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate is gradually ramped up from 0 to the target value. This prevents unstable updates early in training.') warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
add_eos_token = gr.Checkbox(label='Add EOS token', value=True, info="Adds EOS token for each document in text datasets.") add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item. In case of raw text, the EOS will be added at the Hard Cut")
excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown'])
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.') higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True) report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
with gr.Column(): with gr.Column():
with gr.Tab(label='Chat Dataset'): with gr.Tab(label='Formatted Dataset'):
with gr.Row(): with gr.Row():
dataset = gr.Dropdown(choices=utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with chat conversations (messages or ShareGPT format). Each row is one conversation.', elem_classes=['slim-dropdown'], interactive=not mu) format = gr.Dropdown(choices=utils.get_datasets('user_data/training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu) ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/formats', 'json')}, 'refresh-button', interactive=not mu)
with gr.Row(): with gr.Row():
format = gr.Dropdown(choices=get_instruction_templates(), value='None', label='Instruction Template', info='Select an instruction template for formatting the dataset, or "Chat Template" to use the model\'s built-in chat template.', elem_classes=['slim-dropdown'], interactive=not mu) dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_instruction_templates()}, 'refresh-button', interactive=not mu) ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
with gr.Tab(label="Text Dataset"):
with gr.Row():
text_dataset = gr.Dropdown(choices=utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu)
stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=256, step=32, info='Overlap between chunks in tokens. 0 = no overlap. Values like 256 or 512 help preserve context across chunk boundaries.')
with gr.Row(): with gr.Row():
eval_dataset = gr.Dropdown(choices=utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu) eval_dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json')}, 'refresh-button', interactive=not mu) ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.') eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Tab(label="Raw text file"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'txt')}, 'refresh-button', interactive=not mu)
with gr.Row():
with gr.Column():
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='How many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length). Setting overlap to exactly half the cutoff length may be ideal.')
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
with gr.Column():
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number')
with gr.Row(): with gr.Row():
start_button = gr.Button("Start LoRA Training", variant='primary', interactive=not mu) start_button = gr.Button("Start LoRA Training", variant='primary', interactive=not mu)
stop_button = gr.Button("Interrupt", interactive=not mu) stop_button = gr.Button("Interrupt", interactive=not mu)
@ -137,7 +143,7 @@ def create_ui():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True, interactive=not mu) models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True, interactive=not mu)
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'txt')[1:], value='wikitext', label='Input dataset', info=f'The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under {shared.user_data_dir}/training/datasets.', interactive=not mu) evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('user_data/training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under user_data/training/datasets.', interactive=not mu)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.') stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
@ -159,7 +165,7 @@ def create_ui():
refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu) refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
# Training events # Training events
all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to] all_params = [lora_name, always_override, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
copy_from.change(do_copy_params, [copy_from] + all_params, all_params) copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
start_button.click(do_train, all_params, output) start_button.click(do_train, all_params, output)
@ -223,34 +229,9 @@ def clean_path(base_path: str, path: str):
return f'{Path(base_path).absolute()}/{path}' return f'{Path(base_path).absolute()}/{path}'
def get_instruction_templates():
    """Build the choice list for the instruction-template dropdown.

    Collects the stem (file name without extension) of every ``*.yaml``,
    ``*.yml``, ``*.jinja`` and ``*.jinja2`` file directly inside
    ``{user_data_dir}/instruction-templates``; a stem shared by several
    extensions is listed once.

    Returns:
        ``['None', 'Chat Template']`` followed by the collected stems, sorted
        with ``utils.natural_keys`` as the sort key.
    """
    template_dir = shared.user_data_dir / 'instruction-templates'
    stems = {
        entry.stem
        for suffix in ('yaml', 'yml', 'jinja', 'jinja2')
        for entry in template_dir.glob(f'*.{suffix}')
    }
    return ['None', 'Chat Template'] + sorted(stems, key=utils.natural_keys)
def load_template(name):
    """Load a Jinja2 template string from {user_data_dir}/instruction-templates/.

    The extensions ``jinja``, ``jinja2``, ``yaml`` and ``yml`` are tried in
    that order and the first existing ``{name}.{ext}`` file is used.

    Args:
        name: Template name, i.e. the file name without its extension.

    Returns:
        For a ``.jinja``/``.jinja2`` file, the file contents. For a YAML file,
        the value stored under its ``instruction_template`` key. An empty
        string is returned when no matching file exists, when the YAML
        document is empty or not a mapping, or when the key is missing or null.
    """
    path = shared.user_data_dir / 'instruction-templates'
    for ext in ['jinja', 'jinja2', 'yaml', 'yml']:
        filepath = path / f'{name}.{ext}'
        if not filepath.exists():
            continue

        text = filepath.read_text(encoding='utf-8')
        if ext in ['jinja', 'jinja2']:
            return text

        data = yaml.safe_load(text)
        # safe_load() yields None for an empty file and a list/scalar for a
        # non-mapping document; neither has .get(), so treat both as "no template".
        if not isinstance(data, dict):
            return ''

        # A null "instruction_template" entry also counts as no template.
        return data.get('instruction_template') or ''

    return ''
def backup_adapter(input_folder): def backup_adapter(input_folder):
# Get the creation date of the adapter file (safetensors or bin) # Get the creation date of the file adapter_model.bin
try: try:
adapter_file = Path(f"{input_folder}/adapter_model.safetensors")
if not adapter_file.is_file():
adapter_file = Path(f"{input_folder}/adapter_model.bin") adapter_file = Path(f"{input_folder}/adapter_model.bin")
if adapter_file.is_file(): if adapter_file.is_file():
@ -263,7 +244,7 @@ def backup_adapter(input_folder):
subfolder_path.mkdir(parents=True, exist_ok=True) subfolder_path.mkdir(parents=True, exist_ok=True)
# Check if the file already exists in the subfolder # Check if the file already exists in the subfolder
backup_adapter_file = subfolder_path / adapter_file.name backup_adapter_file = Path(f"{input_folder}/{creation_date_str}/adapter_model.bin")
if backup_adapter_file.is_file(): if backup_adapter_file.is_file():
print(" - Backup already exists. Skipping backup process.") print(" - Backup already exists. Skipping backup process.")
return return
@ -293,7 +274,7 @@ def calc_trainable_parameters(model):
return trainable_params, all_param return trainable_params, all_param
def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, excess_length: str, report_to: str): def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
import torch import torch
import transformers import transformers
@ -304,17 +285,21 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
prepare_model_for_kbit_training, prepare_model_for_kbit_training,
set_peft_model_state_dict set_peft_model_state_dict
) )
from peft.utils.other import \
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
model_to_lora_modules
from transformers import is_torch_xpu_available
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
global WANT_INTERRUPT global WANT_INTERRUPT
WANT_INTERRUPT = False WANT_INTERRUPT = False
# == Input validation / processing == # == Input validation / processing ==
yield "Preparing the input..." yield "Preparing the input..."
if shared.args.loader == 'llama.cpp':
yield "Error: LoRA training requires a model loaded with the Transformers loader. GGUF models are not supported for training."
return
lora_file_path = clean_path(None, lora_name) lora_file_path = clean_path(None, lora_name)
if lora_file_path.strip() == '': if lora_file_path.strip() == '':
yield "Missing or invalid LoRA file name input." yield "Missing or invalid LoRA file name input."
@ -324,6 +309,10 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
actual_lr = float(learning_rate) actual_lr = float(learning_rate)
model_type = type(shared.model).__name__ model_type = type(shared.model).__name__
if model_type in MODEL_CLASSES:
model_id = MODEL_CLASSES[model_type]
else:
model_id = "llama"
if model_type == "PeftModelForCausalLM": if model_type == "PeftModelForCausalLM":
if len(shared.lora_names) > 0: if len(shared.lora_names) > 0:
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
@ -331,6 +320,9 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
else: else:
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.") logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
else:
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
time.sleep(5) time.sleep(5)
@ -338,206 +330,166 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
yield "Cannot input zeroes." yield "Cannot input zeroes."
return return
gradient_accumulation_steps = max(1, batch_size // micro_batch_size) gradient_accumulation_steps = batch_size // micro_batch_size
original_chat_template = getattr(shared.tokenizer, 'chat_template', None) shared.tokenizer.pad_token_id = 0
if shared.tokenizer.pad_token_id is None: shared.tokenizer.padding_side = "left"
shared.tokenizer.pad_token_id = shared.tokenizer.eos_token_id
shared.tokenizer.padding_side = "right"
def list_target_modules(): # Populate target_modules list with chosen X_proj modules. Llama-based models only atm, non-llama will revert to default behavior.
if all_linear: def list_target_modules(model_id):
return "all-linear" if model_id != "llama" and model_id != "mistral":
return model_to_lora_modules[model_id]
target_mods = [f"{name}_proj" for name, enabled in { available_modules = {
"q": q_proj_en, "k": k_proj_en, "v": v_proj_en, "o": o_proj_en, "gate": gate_proj_en,
"gate": gate_proj_en, "down": down_proj_en, "up": up_proj_en, "down": down_proj_en,
}.items() if enabled] "up": up_proj_en,
"q": q_proj_en,
"v": v_proj_en,
"k": k_proj_en,
"o": o_proj_en,
}
target_mods = [f"{name}_proj" for name, enabled in available_modules.items() if enabled]
return target_mods return target_mods
def normalize_messages(data_point): def encode(text, add_bos_token):
"""Convert a dataset row to OpenAI messages format for apply_chat_template().""" result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
if "messages" in data_point: # Check if the first two tokens are BOS
return data_point["messages"] if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
result = result[1:]
if "conversations" in data_point: if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
role_map = {"human": "user", "gpt": "assistant"} result = result[1:]
return [ return result
{"role": role_map.get(turn.get("from", ""), turn.get("from", "")), "content": turn["value"]}
for turn in data_point["conversations"]
]
raise RuntimeError( def tokenize(prompt, append_eos_token=False):
f'Dataset row must contain "messages" or "conversations" key. '
f'Found: {list(data_point.keys())}'
)
def tokenize_conversation(data_point): if train_only_after == '' or train_only_after not in prompt:
"""Tokenize using apply_chat_template() with assistant-only label masking.""" input_ids = encode(prompt, True)
messages = normalize_messages(data_point)
full_ids = list(shared.tokenizer.apply_chat_template(messages, tokenize=True, return_dict=False))
# Build labels: -100 for everything, then unmask assistant turns. if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
# This assumes apply_chat_template(messages[:i]) is a token-for-token input_ids.append(shared.tokenizer.eos_token_id)
# prefix of apply_chat_template(messages[:i+1]), which holds for all
# standard chat templates (Llama, ChatML, Mistral, etc.). input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
labels = [-100] * len(full_ids) labels = [1] * len(input_ids)
for i, msg in enumerate(messages):
if msg["role"] == "assistant":
# Tokens up to where this assistant turn starts
header_ids = shared.tokenizer.apply_chat_template(
messages[:i], tokenize=True, return_dict=False, add_generation_prompt=True
)
# Tokens through end of this assistant turn
through_ids = shared.tokenizer.apply_chat_template(
messages[:i + 1], tokenize=True, return_dict=False
)
# Unmask assistant tokens
start = len(header_ids)
end = min(len(through_ids), len(full_ids))
labels[start:end] = full_ids[start:end]
if len(full_ids) > cutoff_len:
if excess_length == 'truncate':
full_ids = full_ids[:cutoff_len]
labels = labels[:cutoff_len]
else: else:
return {"input_ids": [], "labels": [], "attention_mask": []} ind = prompt.index(train_only_after) + len(train_only_after)
before_tokens = encode(prompt[:ind], True)
after_tokens = encode(prompt[ind:], False)
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
after_tokens.append(shared.tokenizer.eos_token_id)
full_length = len(after_tokens) + len(before_tokens)
if full_length > cutoff_len:
after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
else:
before_tokens = [shared.tokenizer.pad_token_id] * (cutoff_len - full_length) + before_tokens
input_ids = before_tokens + after_tokens
labels = [-100] * len(before_tokens) + [1] * len(after_tokens)
input_ids = torch.tensor(input_ids)
return { return {
"input_ids": full_ids, "input_ids": input_ids,
"labels": labels, "labels": labels,
"attention_mask": [1] * len(full_ids), "attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
} }
train_template.clear() train_template.clear()
# == Prep the dataset, format, etc == # == Prep the dataset, format, etc ==
has_text_dataset = text_dataset not in ['None', ''] if raw_text_file not in ['None', '']:
has_chat_dataset = dataset not in ['None', ''] train_template["template_type"] = "raw_text"
if has_text_dataset and has_chat_dataset: logger.info("Loading raw text file dataset")
yield "Error: select either a Chat Dataset or a Text Dataset, not both." fullpath = clean_path('user_data/training/datasets', f'{raw_text_file}')
return fullpath = Path(fullpath)
if fullpath.is_dir():
logger.info('Training path directory {}'.format(raw_text_file))
raw_text = ""
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
for file_path in file_paths:
if file_path.is_file():
with file_path.open('r', encoding='utf-8') as file:
raw_text += file.read().replace('\r', '')
def tokenize_text_data(data): logger.info(f"Loaded training file: {file_path.name}")
"""Tokenize text dataset rows, concatenate, and split into chunks.""" else:
all_tokens = [] with open(clean_path('user_data/training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
for row in data: raw_text = file.read().replace('\r', '')
tokens = shared.tokenizer.encode(row['text'])
cut_string = hard_cut_string.replace('\\n', '\n')
eos_added = 0
out_tokens = []
for text_part in raw_text.split(cut_string):
if len(text_part.strip()) <= min_chars:
continue
tokens = shared.tokenizer.encode(text_part)
if add_eos_token: if add_eos_token:
tokens.append(shared.tokenizer.eos_token_id) tokens.append(shared.tokenizer.eos_token_id)
all_tokens.extend(tokens) eos_added += 1
stride = int(stride_length)
step = cutoff_len - stride if stride > 0 else cutoff_len
step = cutoff_len - overlap_len
if step <= 0: if step <= 0:
return None, "Error: stride length must be smaller than cutoff length." yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
if len(all_tokens) < cutoff_len:
return None, "Error: dataset is too short to fill even one chunk of the given cutoff length."
chunks = []
for start in range(0, len(all_tokens), step):
chunk = all_tokens[start:start + cutoff_len]
if len(chunk) == 0:
break
if len(chunk) < cutoff_len:
pad_len = cutoff_len - len(chunk)
chunks.append({
"input_ids": chunk + [shared.tokenizer.pad_token_id] * pad_len,
"labels": list(chunk) + [-100] * pad_len,
"attention_mask": [1] * len(chunk) + [0] * pad_len,
})
else:
chunks.append({
"input_ids": chunk,
"labels": list(chunk),
"attention_mask": [1] * cutoff_len,
})
return Dataset.from_list(chunks), None
if has_text_dataset:
train_template["template_type"] = "text_dataset"
logger.info("Loading text dataset")
data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{text_dataset}.json'))
if "text" not in data['train'].column_names:
yield "Error: text dataset must have a \"text\" key per row."
return return
train_data, err = tokenize_text_data(data['train']) out_tokens.extend(split_chunks(tokens, cutoff_len, step))
if err:
yield err
return
if eval_dataset == 'None': if eos_added > 0:
print(f"EOS added to {eos_added} text blocks")
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
del out_tokens
if newline_favor_len > 0:
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
del text_chunks
eval_data = None eval_data = None
else: else:
eval_raw = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json')) if dataset in ['None', '']:
if "text" not in eval_raw['train'].column_names: yield "Missing dataset choice input, cannot continue."
yield "Error: evaluation dataset must have a \"text\" key per row."
return return
eval_data, err = tokenize_text_data(eval_raw['train'])
if err:
yield err
return
elif has_chat_dataset:
if format in ['None', '']: if format in ['None', '']:
yield "Missing format choice input, cannot continue." yield "Missing format choice input, cannot continue."
return return
if format == 'Chat Template': train_template["template_type"] = "dataset"
if not getattr(shared.tokenizer, 'chat_template', None):
yield "Error: this model's tokenizer does not have a chat template. Select an instruction template instead, or load an instruct/chat model."
return
else:
# Load custom instruction template and set on tokenizer
template_str = load_template(format)
if not template_str:
yield f"Error: could not load instruction template '{format}'."
return
shared.tokenizer.chat_template = template_str
# Unified path — both cases use tokenize_conversation() with open(clean_path('user_data/training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
train_template["template_type"] = "chat_template" format_data: dict[str, str] = json.load(formatFile)
logger.info("Loading JSON dataset with chat template format") # == store training prompt ==
data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{dataset}.json')) for _, value in format_data.items():
prompt_key = f"template_{len(train_template)}"
train_template[prompt_key] = value
# Validate the first row def generate_prompt(data_point: dict[str, str]):
try: for options, data in format_data.items():
normalize_messages(data['train'][0]) if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
except (RuntimeError, KeyError, IndexError) as e: for key, val in data_point.items():
yield f"Error: {e}" if type(val) is str:
return data = data.replace(f'%{key}%', val)
return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
total = len(data['train']) def generate_and_tokenize_prompt(data_point):
train_data = data['train'].map( prompt = generate_prompt(data_point)
tokenize_conversation, return tokenize(prompt, add_eos_token)
remove_columns=data['train'].column_names,
new_fingerprint='%030x' % random.randrange(16**30) logger.info("Loading JSON datasets")
) data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
train_data = train_data.filter(lambda x: len(x['input_ids']) > 0) train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
dropped = total - len(train_data)
if dropped > 0:
logger.warning(f"Dropped {dropped}/{total} conversations exceeding cutoff length of {cutoff_len} tokens.")
if len(train_data) == 0:
yield f"Error: all {total} conversations exceed the cutoff length of {cutoff_len} tokens. Increase the cutoff length or shorten your data."
return
if eval_dataset == 'None': if eval_dataset == 'None':
eval_data = None eval_data = None
else: else:
eval_data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json')) eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].map( eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
tokenize_conversation,
remove_columns=eval_data['train'].column_names,
new_fingerprint='%030x' % random.randrange(16**30)
)
eval_data = eval_data.filter(lambda x: len(x['input_ids']) > 0)
else:
yield "No dataset selected. Choose a Chat Dataset or a Text Dataset."
return
# == We MUST reload model if it went through any previous training, even failed one == # == We MUST reload model if it went through any previous training, even failed one ==
if shared.model_dirty_from_training: if shared.model_dirty_from_training:
@ -550,14 +502,12 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
if shared.model is not None: if shared.model is not None:
print("Model reloaded OK, continue with training.") print("Model reloaded OK, continue with training.")
else: else:
yield f"Failed to load {selected_model}." return f"Failed to load {selected_model}."
return except:
except Exception:
exc = traceback.format_exc() exc = traceback.format_exc()
logger.error('Failed to reload the model.') logger.error('Failed to reload the model.')
print(exc) print(exc)
yield exc.replace('\n', '\n\n') return exc.replace('\n', '\n\n')
return
# == Start prepping the model itself == # == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
@ -569,15 +519,10 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
shared.model_dirty_from_training = True shared.model_dirty_from_training = True
logger.info("Preparing for training") logger.info("Preparing for training")
target_modules = list_target_modules()
if not target_modules:
yield "No target modules selected. Enable at least one module or check 'Target all linear layers'."
return
config = LoraConfig( config = LoraConfig(
r=lora_rank, r=lora_rank,
lora_alpha=lora_alpha, lora_alpha=lora_alpha,
target_modules=target_modules, target_modules=list_target_modules(model_id),
lora_dropout=lora_dropout, lora_dropout=lora_dropout,
bias="none", bias="none",
task_type="CAUSAL_LM" task_type="CAUSAL_LM"
@ -590,31 +535,14 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
# == get model trainable params # == get model trainable params
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model) model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
# == Determine if we can resume from a checkpoint ==
resume_checkpoint = None
try: try:
logger.info("Creating LoRA model") logger.info("Creating LoRA model")
lora_model = get_peft_model(shared.model, config) lora_model = get_peft_model(shared.model, config)
if not always_override and Path(lora_file_path).exists(): if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
# Look for HF Trainer checkpoint dirs (full resumption) logger.info("Loading existing LoRA data")
checkpoints = sorted(Path(lora_file_path).glob("checkpoint-*"), key=os.path.getmtime) state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
if checkpoints:
resume_checkpoint = str(checkpoints[-1])
logger.info(f"Will resume from checkpoint: {resume_checkpoint}")
else:
# Legacy fallback: load bare adapter weights only
safetensors_path = Path(f"{lora_file_path}/adapter_model.safetensors")
bin_path = Path(f"{lora_file_path}/adapter_model.bin")
if safetensors_path.is_file():
logger.info("Loading existing LoRA data (safetensors)")
from safetensors.torch import load_file
state_dict_peft = load_file(str(safetensors_path))
set_peft_model_state_dict(lora_model, state_dict_peft) set_peft_model_state_dict(lora_model, state_dict_peft)
elif bin_path.is_file(): except:
logger.info("Loading existing LoRA data (bin)")
state_dict_peft = torch.load(str(bin_path), weights_only=True)
set_peft_model_state_dict(lora_model, state_dict_peft)
except Exception:
yield traceback.format_exc().replace('\n', '\n\n') yield traceback.format_exc().replace('\n', '\n\n')
return return
@ -634,6 +562,14 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
if WANT_INTERRUPT: if WANT_INTERRUPT:
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
# Save log
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2)
# == Save training prompt ==
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_prompt.json", 'w', encoding='utf-8') as file:
json.dump(train_template, file, indent=2)
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
tracked.current_steps += 1 tracked.current_steps += 1
@ -650,46 +586,22 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='') print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='')
if 'loss' in logs: if 'loss' in logs:
loss = float(logs['loss']) loss = float(logs['loss'])
if stop_at_loss > 0 and loss <= stop_at_loss: if loss <= stop_at_loss:
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m") print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
def on_save(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
    """Write the training log and training prompt into the current checkpoint folder.

    The folder is ``<args.output_dir>/checkpoint-<state.global_step>``. Nothing is
    written when that folder does not exist. ``train_log`` and ``train_template``
    are taken from the enclosing scope.

    NOTE(review): presumably invoked by the Trainer right after it has written a
    checkpoint, which is why the folder is only checked for, never created - confirm.
    """
    checkpoint_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
    if checkpoint_dir.exists():
        # Keep the log next to the weights so a checkpoint is self-describing
        with open(checkpoint_dir / "training_log.json", 'w', encoding='utf-8') as file:
            json.dump(train_log, file, indent=2)

        with open(checkpoint_dir / "training_prompt.json", 'w', encoding='utf-8') as file:
            json.dump(train_template, file, indent=2)
# Fix training for mixed precision models # Fix training for mixed precision models
for param in shared.model.parameters(): for param in shared.model.parameters():
if param.requires_grad: if param.requires_grad:
param.data = param.data.float() param.data = param.data.float()
lora_model.config.use_cache = False
def collate_fn(batch):
    """Right-pad every row of ``batch`` to the longest row and stack the result into tensors.

    Each item must expose list-valued 'input_ids', 'labels' and 'attention_mask'.
    Padding values: ``shared.tokenizer.pad_token_id`` for input ids, -100 for labels
    and 0 for the attention mask.

    Returns a dict with the same three keys, each holding a ``torch.Tensor``.
    Raises ``ValueError`` (from ``max``) when ``batch`` is empty.
    """
    max_len = max(len(item['input_ids']) for item in batch)
    input_ids, labels, attention_mask = [], [], []
    for item in batch:
        # The pad length is derived from input_ids; labels and mask get the same amount
        pad_len = max_len - len(item['input_ids'])
        input_ids.append(item['input_ids'] + [shared.tokenizer.pad_token_id] * pad_len)
        labels.append(item['labels'] + [-100] * pad_len)
        attention_mask.append(item['attention_mask'] + [0] * pad_len)

    return {
        'input_ids': torch.tensor(input_ids),
        'labels': torch.tensor(labels),
        'attention_mask': torch.tensor(attention_mask),
    }
trainer = transformers.Trainer( trainer = transformers.Trainer(
model=lora_model, model=lora_model,
train_dataset=train_data, train_dataset=train_data,
eval_dataset=eval_data, eval_dataset=eval_data,
args=transformers.TrainingArguments( args=transformers.TrainingArguments(
report_to=report_to if report_to != "None" else "none", report_to=report_to if report_to != "None" else None,
per_device_train_batch_size=micro_batch_size, per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps, gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps), warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
@ -698,27 +610,31 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
fp16=False if shared.args.cpu or shared.args.bf16 else True, fp16=False if shared.args.cpu or shared.args.bf16 else True,
bf16=shared.args.bf16, bf16=shared.args.bf16,
optim=optimizer, optim=optimizer,
logging_steps=1, logging_steps=2 if stop_at_loss > 0 else 5,
eval_strategy="steps" if eval_data is not None else "no", eval_strategy="steps" if eval_data is not None else "no",
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None, eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
save_strategy="steps" if save_steps > 0 or eval_data is not None else "no", save_strategy="steps" if eval_data is not None else "no",
save_steps=actual_save_steps if save_steps > 0 else None,
output_dir=lora_file_path, output_dir=lora_file_path,
lr_scheduler_type=lr_scheduler_type, lr_scheduler_type=lr_scheduler_type,
load_best_model_at_end=eval_data is not None, load_best_model_at_end=eval_data is not None,
# TODO: Enable multi-device support # TODO: Enable multi-device support
ddp_find_unused_parameters=None, ddp_find_unused_parameters=None,
use_cpu=shared.args.cpu, no_cuda=shared.args.cpu,
remove_unused_columns=False, # use_ipex=True if is_torch_xpu_available() and not shared.args.cpu else False
), ),
data_collator=collate_fn, data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
callbacks=[Callbacks()] callbacks=list([Callbacks()])
) )
lora_model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
lora_model = torch.compile(lora_model)
# == Save parameters for reuse == # == Save parameters for reuse ==
with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file: with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
local_vars = locals() vars = locals()
json.dump({x: local_vars[x] for x in PARAMETERS}, file, indent=2) json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
# == Save training prompt == # == Save training prompt ==
with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file: with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
@ -730,12 +646,9 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model) lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
if target_modules == "all-linear": projections_string = ", ".join([projection.replace("_proj", "") for projection in list_target_modules(model_id)])
projections_string = "all-linear"
else:
projections_string = ", ".join([projection.replace("_proj", "") for projection in target_modules])
print(f"Training '{model_type}' model using ({projections_string}) projections") print(f"Training '{model_id}' model using ({projections_string}) projections")
if lora_all_param > 0: if lora_all_param > 0:
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})") print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
@ -763,31 +676,23 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
decoded_entries.append({"value": decoded_text}) decoded_entries.append({"value": decoded_text})
# Write the log file # Write the log file
(shared.user_data_dir / 'logs').mkdir(exist_ok=True) Path('user_data/logs').mkdir(exist_ok=True)
with open(shared.user_data_dir / 'logs' / 'train_dataset_sample.json', 'w') as json_file: with open(Path('user_data/logs/train_dataset_sample.json'), 'w') as json_file:
json.dump(decoded_entries, json_file, indent=4) json.dump(decoded_entries, json_file, indent=4)
logger.info(f"Log file 'train_dataset_sample.json' created in the '{shared.user_data_dir}/logs' directory.") logger.info("Log file 'train_dataset_sample.json' created in the 'user_data/logs' directory.")
except Exception as e: except Exception as e:
logger.error(f"Failed to create log file due to error: {e}") logger.error(f"Failed to create log file due to error: {e}")
thread_error = None
def threaded_run(): def threaded_run():
nonlocal thread_error
try:
log_train_dataset(trainer) log_train_dataset(trainer)
trainer.train(resume_from_checkpoint=resume_checkpoint) trainer.train()
# Note: save in the thread in case the gradio thread breaks (eg browser closed) # Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path) lora_model.save_pretrained(lora_file_path)
tracked.did_save = True
logger.info("LoRA training run is completed and saved.") logger.info("LoRA training run is completed and saved.")
# Save log # Save log
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file: with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2) json.dump(train_log, file, indent=2)
except Exception as e:
thread_error = e
logger.error(f"Training error: {e}")
thread = threading.Thread(target=threaded_run) thread = threading.Thread(target=threaded_run)
thread.start() thread.start()
@ -816,20 +721,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining" yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
# Check for errors from the training thread
if thread_error is not None:
yield f"Training failed: {thread_error}"
return
# Saving in the train thread might fail if an error occurs, so save here if so. # Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked.did_save: if not tracked.did_save:
logger.info("Training complete, saving") logger.info("Training complete, saving")
lora_model.save_pretrained(lora_file_path) lora_model.save_pretrained(lora_file_path)
# Restore the original chat_template if we changed it for training
if shared.tokenizer is not None and hasattr(shared.tokenizer, 'chat_template'):
shared.tokenizer.chat_template = original_chat_template
if WANT_INTERRUPT: if WANT_INTERRUPT:
logger.info("Training interrupted.") logger.info("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`." yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`."
@ -838,6 +734,29 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training." yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training."
def split_chunks(arr, size, step):
    """Yield successive slices of ``arr`` that are at most ``size`` long, advancing by ``step``.

    Slices overlap when ``step`` is smaller than ``size``. The trailing slices may be
    shorter than ``size`` because slicing simply stops at the end of ``arr``.

    Raises ``ValueError`` (on first iteration, since this is a generator) when
    ``step`` is not positive. A zero step already raised ``ValueError`` through
    ``range``; a negative step used to yield nothing silently.
    """
    if step <= 0:
        raise ValueError(f"step must be a positive integer, got {step}")

    for start in range(0, len(arr), step):
        yield arr[start:start + size]
def cut_chunk_for_newline(chunk: str, max_length: int):
    """Trim the leading and/or trailing partial line of ``chunk`` when a newline is near the edge.

    - If the first newline sits at an index below ``max_length``, everything up to and
      including that newline is dropped.
    - Then, if the last newline of what remains is fewer than ``max_length`` characters
      from the end (counting the newline itself), the text from that newline on is dropped.

    A chunk without any newline is returned unchanged.
    """
    first_newline = chunk.find('\n')
    if first_newline == -1:
        return chunk

    if first_newline < max_length:
        chunk = chunk[first_newline + 1:]

    # Search again: the head cut above may have removed the only newline
    last_newline = chunk.rfind('\n')
    if last_newline == -1:
        return chunk

    if len(chunk) - last_newline < max_length:
        chunk = chunk[:last_newline]

    return chunk
def format_time(seconds: float): def format_time(seconds: float):
if seconds < 120: if seconds < 120:
return f"`{seconds:.0f}` seconds" return f"`{seconds:.0f}` seconds"

View file

@ -1,3 +1,4 @@
import os
import pprint import pprint
from pathlib import Path from pathlib import Path
@ -5,7 +6,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import transformers import transformers
from accelerate import infer_auto_device_map, init_empty_weights from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import is_xpu_available from accelerate.utils import (
is_ccl_available,
is_npu_available,
is_xpu_available
)
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoModel, AutoModel,
@ -23,6 +28,31 @@ from modules.torch_utils import get_device
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# Module-level DeepSpeed bootstrap: only runs when the --deepspeed flag is set.
local_rank = None
if shared.args.deepspeed:
    import deepspeed
    from transformers.integrations.deepspeed import (
        HfDeepSpeedConfig,
        is_deepspeed_zero3_enabled
    )

    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup
    # An explicit --local_rank wins; otherwise fall back to the LOCAL_RANK environment variable.
    local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if is_xpu_available() and is_ccl_available():
        torch.xpu.set_device(local_rank)
        # The keyword is `dist_backend`; `backend=` is not a parameter of init_distributed
        deepspeed.init_distributed(dist_backend="ccl")
    elif is_npu_available():
        torch.npu.set_device(local_rank)
        deepspeed.init_distributed(dist_backend="hccl")
    else:
        torch.cuda.set_device(local_rank)
        deepspeed.init_distributed()

    ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria): class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
def __init__(self): def __init__(self):
@ -65,16 +95,14 @@ class LogprobProcessor(LogitsProcessor):
def __init__(self, logprobs=None): def __init__(self, logprobs=None):
self.logprobs = logprobs self.logprobs = logprobs
self.token_alternatives = {} self.token_alternatives = {}
self.token_alternatives_history = []
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logprobs is not None: # 0-5 if self.logprobs is not None: # 0-5
log_e_probabilities = F.log_softmax(logits, dim=1) log_e_probabilities = F.log_softmax(logits, dim=1)
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs) top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]] top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
top_probs = [float(x) for x in top_values[0]] top_probs = [float(x) for x in top_values[0]]
self.token_alternatives = dict(zip(top_tokens, top_probs)) self.token_alternatives = dict(zip(top_tokens, top_probs))
self.token_alternatives_history.append(self.token_alternatives)
return logits return logits
@ -135,7 +163,10 @@ def load_model_HF(model_name):
shared.args.load_in_8bit, shared.args.load_in_8bit,
shared.args.load_in_4bit, shared.args.load_in_4bit,
shared.args.disk, shared.args.disk,
shared.args.deepspeed,
shared.args.cpu_memory is not None, shared.args.cpu_memory is not None,
shared.args.compress_pos_emb > 1,
shared.args.alpha_value > 1,
]) ])
# Load the model without any special settings # Load the model without any special settings
@ -152,6 +183,25 @@ def load_model_HF(model_name):
if device: if device:
model = model.to(device) model = model.to(device)
# DeepSpeed ZeRO-3
elif shared.args.deepspeed:
model = LoaderClass.from_pretrained(
path_to_model,
torch_dtype=params['torch_dtype'],
trust_remote_code=params.get('trust_remote_code')
)
model = deepspeed.initialize(
model=model,
config_params=ds_config,
model_parameters=None,
optimizer=None,
lr_scheduler=None
)[0]
model.module.eval() # Inference
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
# Load with quantization and/or offloading # Load with quantization and/or offloading
else: else:
if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())): if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
@ -198,6 +248,11 @@ def load_model_HF(model_name):
if shared.args.disk: if shared.args.disk:
params['offload_folder'] = str(Path(shared.args.disk_cache_dir)) params['offload_folder'] = str(Path(shared.args.disk_cache_dir))
if shared.args.compress_pos_emb > 1:
params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
elif shared.args.alpha_value > 1:
params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}
logger.info("TRANSFORMERS_PARAMS=") logger.info("TRANSFORMERS_PARAMS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params) pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
print() print()

View file

@ -113,15 +113,65 @@ if not shared.args.old_colors:
block_radius='0', block_radius='0',
) )
if (shared.user_data_dir / "notification.mp3").exists(): if Path("user_data/notification.mp3").exists():
audio_notification_js = "document.querySelector('#audio_notification audio')?.play();" audio_notification_js = "document.querySelector('#audio_notification audio')?.play();"
else: else:
audio_notification_js = "" audio_notification_js = ""
def list_model_elements(): def list_model_elements():
from modules.loaders import list_model_elements elements = [
return list_model_elements() 'filter_by_loader',
'loader',
'cpu_memory',
'gpu_layers',
'cpu_moe',
'threads',
'threads_batch',
'batch_size',
'ubatch_size',
'ctx_size',
'cache_type',
'tensor_split',
'extra_flags',
'streaming_llm',
'gpu_split',
'alpha_value',
'rope_freq_base',
'compress_pos_emb',
'compute_dtype',
'quant_type',
'num_experts_per_token',
'load_in_8bit',
'load_in_4bit',
'attn_implementation',
'cpu',
'disk',
'row_split',
'no_kv_offload',
'no_mmap',
'mlock',
'numa',
'use_double_quant',
'bf16',
'autosplit',
'enable_tp',
'tp_backend',
'no_flash_attn',
'no_xformers',
'no_sdpa',
'cfg_cache',
'cpp_runner',
'no_use_fast',
'model_draft',
'draft_max',
'gpu_layers_draft',
'device_draft',
'ctx_size_draft',
'mmproj',
]
return elements
def list_interface_input_elements(): def list_interface_input_elements():
@ -143,8 +193,6 @@ def list_interface_input_elements():
'tfs', 'tfs',
'top_a', 'top_a',
'top_n_sigma', 'top_n_sigma',
'adaptive_target',
'adaptive_decay',
'dry_multiplier', 'dry_multiplier',
'dry_allowed_length', 'dry_allowed_length',
'dry_base', 'dry_base',
@ -199,12 +247,10 @@ def list_interface_input_elements():
'unique_id', 'unique_id',
'textbox', 'textbox',
'start_with', 'start_with',
'selected_tools',
'mode', 'mode',
'chat_style', 'chat_style',
'chat-instruct_command', 'chat-instruct_command',
'character_menu', 'character_menu',
'user_menu',
'name2', 'name2',
'context', 'context',
'greeting', 'greeting',
@ -304,16 +350,10 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma
if k in shared.settings and k not in exclude: if k in shared.settings and k not in exclude:
output[k] = state[k] output[k] = state[k]
if preset:
output['preset'] = preset output['preset'] = preset
output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook'] output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook']
if state.get('character_menu'):
output['character'] = state['character_menu'] output['character'] = state['character_menu']
if state.get('user_menu'):
output['user'] = state['user_menu']
output['seed'] = int(output['seed']) output['seed'] = int(output['seed'])
output['custom_stopping_strings'] = output.get('custom_stopping_strings') or ''
output['custom_token_bans'] = output.get('custom_token_bans') or ''
output['show_controls'] = show_controls output['show_controls'] = show_controls
output['dark_theme'] = True if theme_state == 'dark' else False output['dark_theme'] = True if theme_state == 'dark' else False
output.pop('instruction_template_str') output.pop('instruction_template_str')
@ -337,7 +377,7 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma
output[_id] = params[param] output[_id] = params[param]
else: else:
# Preserve existing extensions and extension parameters during autosave # Preserve existing extensions and extension parameters during autosave
settings_path = shared.user_data_dir / 'settings.yaml' settings_path = Path('user_data') / 'settings.yaml'
if settings_path.exists(): if settings_path.exists():
try: try:
with open(settings_path, 'r', encoding='utf-8') as f: with open(settings_path, 'r', encoding='utf-8') as f:
@ -392,7 +432,7 @@ def _perform_debounced_save():
try: try:
if _last_interface_state is not None: if _last_interface_state is not None:
contents = save_settings(_last_interface_state, _last_preset, _last_extensions, _last_show_controls, _last_theme_state, manual_save=False) contents = save_settings(_last_interface_state, _last_preset, _last_extensions, _last_show_controls, _last_theme_state, manual_save=False)
settings_path = shared.user_data_dir / 'settings.yaml' settings_path = Path('user_data') / 'settings.yaml'
settings_path.parent.mkdir(exist_ok=True) settings_path.parent.mkdir(exist_ok=True)
with open(settings_path, 'w', encoding='utf-8') as f: with open(settings_path, 'w', encoding='utf-8') as f:
f.write(contents) f.write(contents)
@ -417,7 +457,6 @@ def setup_auto_save():
'chat_style', 'chat_style',
'chat-instruct_command', 'chat-instruct_command',
'character_menu', 'character_menu',
'user_menu',
'name1', 'name1',
'name2', 'name2',
'context', 'context',
@ -425,7 +464,6 @@ def setup_auto_save():
'user_bio', 'user_bio',
'custom_system_message', 'custom_system_message',
'chat_template_str', 'chat_template_str',
'selected_tools',
# Parameters tab (ui_parameters.py) - Generation parameters # Parameters tab (ui_parameters.py) - Generation parameters
'preset_menu', 'preset_menu',
@ -446,8 +484,6 @@ def setup_auto_save():
'tfs', 'tfs',
'top_a', 'top_a',
'top_n_sigma', 'top_n_sigma',
'adaptive_target',
'adaptive_decay',
'dry_multiplier', 'dry_multiplier',
'dry_allowed_length', 'dry_allowed_length',
'dry_base', 'dry_base',
@ -476,6 +512,7 @@ def setup_auto_save():
'skip_special_tokens', 'skip_special_tokens',
'stream', 'stream',
'static_cache', 'static_cache',
'truncation_length',
'seed', 'seed',
'sampler_priority', 'sampler_priority',
'custom_stopping_strings', 'custom_stopping_strings',

View file

@ -19,7 +19,7 @@ def create_ui():
shared.gradio['Chat input'] = gr.State() shared.gradio['Chat input'] = gr.State()
shared.gradio['history'] = gr.State({'internal': [], 'visible': [], 'metadata': {}}) shared.gradio['history'] = gr.State({'internal': [], 'visible': [], 'metadata': {}})
shared.gradio['display'] = gr.Headless(value={}) shared.gradio['display'] = gr.JSON(value={}, visible=False) # Hidden buffer
with gr.Tab('Chat', elem_id='chat-tab'): with gr.Tab('Chat', elem_id='chat-tab'):
with gr.Row(elem_id='past-chats-row', elem_classes=['pretty_scrollbar']): with gr.Row(elem_id='past-chats-row', elem_classes=['pretty_scrollbar']):
@ -28,8 +28,7 @@ def create_ui():
shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu) shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu)
shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu) shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu)
shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat') shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat')
shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'], elem_id='new-chat-btn') shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'])
shared.gradio['Start incognito chat'] = gr.Button('Incognito chat', visible=False, elem_id='incognito-chat-btn')
shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True) shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True)
shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat') shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat')
@ -92,21 +91,6 @@ def create_ui():
gr.HTML("<div class='sidebar-vertical-separator'></div>") gr.HTML("<div class='sidebar-vertical-separator'></div>")
from modules.tool_use import get_available_tools
shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=shared.settings.get('selected_tools', []), label='Tools', info='Functions the model can call during generation.', elem_id='tools-group')
shared.gradio['tools_refresh'] = gr.Button('Refresh list', elem_id='tools-refresh-btn', visible=False)
shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']])
def sync_web_tools(selected):
if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools():
selected.append('fetch_webpage')
return gr.update(value=selected)
shared.gradio['selected_tools'].change(fn=sync_web_tools, inputs=[shared.gradio['selected_tools']], outputs=[shared.gradio['selected_tools']], show_progress=False)
gr.HTML("<div class='sidebar-vertical-separator'></div>")
with gr.Row(): with gr.Row():
shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode') shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')
@ -153,12 +137,6 @@ def create_character_settings_ui():
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=5, label='Greeting', elem_classes=['add_scrollbar'], elem_id="character-greeting") shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=5, label='Greeting', elem_classes=['add_scrollbar'], elem_id="character-greeting")
with gr.Tab("User"): with gr.Tab("User"):
with gr.Row():
shared.gradio['user_menu'] = gr.Dropdown(value=shared.settings['user'], choices=utils.get_available_users(), label='User', elem_id='user-menu', info='Select a user profile.', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['user_menu'], lambda: None, lambda: {'choices': utils.get_available_users()}, 'refresh-button', interactive=not mu)
shared.gradio['save_user'] = gr.Button('💾', elem_classes='refresh-button', elem_id="save-user", interactive=not mu)
shared.gradio['delete_user'] = gr.Button('🗑️', elem_classes='refresh-button', interactive=not mu)
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Name') shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Name')
shared.gradio['user_bio'] = gr.Textbox(value=shared.settings['user_bio'], lines=10, label='Description', info='Here you can optionally write a description of yourself.', placeholder='{{user}}\'s personality: ...', elem_classes=['add_scrollbar'], elem_id="user-description") shared.gradio['user_bio'] = gr.Textbox(value=shared.settings['user_bio'], lines=10, label='Description', info='Here you can optionally write a description of yourself.', placeholder='{{user}}\'s personality: ...', elem_classes=['add_scrollbar'], elem_id="user-description")
@ -191,7 +169,7 @@ def create_character_settings_ui():
with gr.Column(scale=1): with gr.Column(scale=1):
shared.gradio['character_picture'] = gr.Image(label='Character picture', type='filepath', interactive=not mu) shared.gradio['character_picture'] = gr.Image(label='Character picture', type='filepath', interactive=not mu)
shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(shared.user_data_dir / 'cache' / 'pfp_me.png') if (shared.user_data_dir / 'cache' / 'pfp_me.png').exists() else None, interactive=not mu) shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(Path('user_data/cache/pfp_me.png')) if Path('user_data/cache/pfp_me.png').exists() else None, interactive=not mu)
def create_chat_settings_ui(): def create_chat_settings_ui():
@ -291,10 +269,6 @@ def create_event_handlers():
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False) chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
shared.gradio['Start incognito chat'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.handle_start_incognito_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
shared.gradio['delete_chat-confirm'].click( shared.gradio['delete_chat-confirm'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False) chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
@ -350,13 +324,13 @@ def create_event_handlers():
shared.gradio['load_template'].click(chat.handle_load_template_click, gradio('instruction_template'), gradio('instruction_template_str', 'instruction_template'), show_progress=False) shared.gradio['load_template'].click(chat.handle_load_template_click, gradio('instruction_template'), gradio('instruction_template_str', 'instruction_template'), show_progress=False)
shared.gradio['save_template'].click( shared.gradio['save_template'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.handle_save_template_click, gradio('instruction_template_str'), gradio('save_filename', 'save_root', 'save_contents', 'save_root_state', 'file_saver'), show_progress=False) chat.handle_save_template_click, gradio('instruction_template_str'), gradio('save_filename', 'save_root', 'save_contents', 'file_saver'), show_progress=False)
shared.gradio['restore_character'].click( shared.gradio['restore_character'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.restore_character_for_ui, gradio('interface_state'), gradio('interface_state', 'name2', 'context', 'greeting', 'character_picture'), show_progress=False) chat.restore_character_for_ui, gradio('interface_state'), gradio('interface_state', 'name2', 'context', 'greeting', 'character_picture'), show_progress=False)
shared.gradio['delete_template'].click(chat.handle_delete_template_click, gradio('instruction_template'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False) shared.gradio['delete_template'].click(chat.handle_delete_template_click, gradio('instruction_template'), gradio('delete_filename', 'delete_root', 'file_deleter'), show_progress=False)
shared.gradio['save_chat_history'].click( shared.gradio['save_chat_history'].click(
lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then( lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then(
None, gradio('temporary_text', 'character_menu', 'mode'), None, js=f'(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}') None, gradio('temporary_text', 'character_menu', 'mode'), None, js=f'(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}')
@ -398,11 +372,3 @@ def create_event_handlers():
gradio('enable_web_search'), gradio('enable_web_search'),
gradio('web_search_row') gradio('web_search_row')
) )
# User menu event handlers
shared.gradio['user_menu'].change(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.handle_user_menu_change, gradio('interface_state'), gradio('name1', 'user_bio', 'your_picture'), show_progress=False)
shared.gradio['save_user'].click(chat.handle_save_user_click, gradio('name1'), gradio('save_user_filename', 'user_saver'), show_progress=False)
shared.gradio['delete_user'].click(lambda: gr.update(visible=True), None, gradio('user_deleter'), show_progress=False)

View file

@ -159,7 +159,7 @@ def handle_new_prompt():
new_name = utils.current_time() new_name = utils.current_time()
# Create the new prompt file # Create the new prompt file
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.parent.mkdir(parents=True, exist_ok=True)
prompt_path.write_text("In this story,", encoding='utf-8') prompt_path.write_text("In this story,", encoding='utf-8')
@ -170,15 +170,15 @@ def handle_delete_prompt_confirm_default(prompt_name):
available_prompts = utils.get_available_prompts() available_prompts = utils.get_available_prompts()
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0 current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
(shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True) (Path("user_data/logs/notebook") / f"{prompt_name}.txt").unlink(missing_ok=True)
available_prompts = utils.get_available_prompts() available_prompts = utils.get_available_prompts()
if available_prompts: if available_prompts:
new_value = available_prompts[min(current_index, len(available_prompts) - 1)] new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
else: else:
new_value = utils.current_time() new_value = utils.current_time()
(shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True) Path("user_data/logs/notebook").mkdir(parents=True, exist_ok=True)
(shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,") (Path("user_data/logs/notebook") / f"{new_value}.txt").write_text("In this story,")
available_prompts = [new_value] available_prompts = [new_value]
return [ return [
@ -199,8 +199,8 @@ def handle_rename_prompt_click_default(current_name):
def handle_rename_prompt_confirm_default(new_name, current_name): def handle_rename_prompt_confirm_default(new_name, current_name):
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt" old_path = Path("user_data/logs/notebook") / f"{current_name}.txt"
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" new_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
if old_path.exists() and not new_path.exists(): if old_path.exists() and not new_path.exists():
old_path.rename(new_path) old_path.rename(new_path)

View file

@ -3,18 +3,12 @@ import traceback
import gradio as gr import gradio as gr
from modules import chat, presets, shared, ui, utils from modules import chat, presets, shared, ui, utils
from modules.utils import gradio, sanitize_filename from modules.utils import gradio
def create_ui(): def create_ui():
mu = shared.args.multi_user mu = shared.args.multi_user
# Server-side per-session root paths for the generic file saver/deleter.
# Set by the handler that opens the dialog, read by the confirm handler.
# Using gr.State so they are session-scoped and safe for multi-user.
shared.gradio['save_root_state'] = gr.State(None)
shared.gradio['delete_root_state'] = gr.State(None)
# Text file saver # Text file saver
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']: with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']:
shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name') shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name')
@ -34,7 +28,7 @@ def create_ui():
# Character saver/deleter # Character saver/deleter
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']: with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']:
shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info=f'The character will be saved to your {shared.user_data_dir}/characters folder with this base filename.') shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info='The character will be saved to your user_data/characters folder with this base filename.')
with gr.Row(): with gr.Row():
shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu) shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)
@ -45,22 +39,9 @@ def create_ui():
shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu) shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu)
# User saver/deleter
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_saver']:
shared.gradio['save_user_filename'] = gr.Textbox(lines=1, label='File name', info=f'The user profile will be saved to your {shared.user_data_dir}/users folder with this base filename.')
with gr.Row():
shared.gradio['save_user_cancel'] = gr.Button('Cancel', elem_classes="small-button")
shared.gradio['save_user_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_deleter']:
gr.Markdown('Confirm the user deletion?')
with gr.Row():
shared.gradio['delete_user_cancel'] = gr.Button('Cancel', elem_classes="small-button")
shared.gradio['delete_user_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu)
# Preset saver # Preset saver
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['preset_saver']: with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['preset_saver']:
shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info=f'The preset will be saved to your {shared.user_data_dir}/presets folder with this base filename.') shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info='The preset will be saved to your user_data/presets folder with this base filename.')
shared.gradio['save_preset_contents'] = gr.Textbox(lines=10, label='File contents') shared.gradio['save_preset_contents'] = gr.Textbox(lines=10, label='File contents')
with gr.Row(): with gr.Row():
shared.gradio['save_preset_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['save_preset_cancel'] = gr.Button('Cancel', elem_classes="small-button")
@ -72,13 +53,13 @@ def create_event_handlers():
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
handle_save_preset_click, gradio('interface_state'), gradio('save_preset_contents', 'save_preset_filename', 'preset_saver'), show_progress=False) handle_save_preset_click, gradio('interface_state'), gradio('save_preset_contents', 'save_preset_filename', 'preset_saver'), show_progress=False)
shared.gradio['delete_preset'].click(handle_delete_preset_click, gradio('preset_menu'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False) shared.gradio['delete_preset'].click(handle_delete_preset_click, gradio('preset_menu'), gradio('delete_filename', 'delete_root', 'file_deleter'), show_progress=False)
shared.gradio['save_grammar'].click(handle_save_grammar_click, gradio('grammar_string'), gradio('save_contents', 'save_filename', 'save_root', 'save_root_state', 'file_saver'), show_progress=False) shared.gradio['save_grammar'].click(handle_save_grammar_click, gradio('grammar_string'), gradio('save_contents', 'save_filename', 'save_root', 'file_saver'), show_progress=False)
shared.gradio['delete_grammar'].click(handle_delete_grammar_click, gradio('grammar_file'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False) shared.gradio['delete_grammar'].click(handle_delete_grammar_click, gradio('grammar_file'), gradio('delete_filename', 'delete_root', 'file_deleter'), show_progress=False)
shared.gradio['save_preset_confirm'].click(handle_save_preset_confirm_click, gradio('save_preset_filename', 'save_preset_contents'), gradio('preset_menu', 'preset_saver'), show_progress=False) shared.gradio['save_preset_confirm'].click(handle_save_preset_confirm_click, gradio('save_preset_filename', 'save_preset_contents'), gradio('preset_menu', 'preset_saver'), show_progress=False)
shared.gradio['save_confirm'].click(handle_save_confirm_click, gradio('save_root_state', 'save_filename', 'save_contents'), gradio('save_root_state', 'file_saver'), show_progress=False) shared.gradio['save_confirm'].click(handle_save_confirm_click, gradio('save_root', 'save_filename', 'save_contents'), gradio('file_saver'), show_progress=False)
shared.gradio['delete_confirm'].click(handle_delete_confirm_click, gradio('delete_root_state', 'delete_filename'), gradio('delete_root_state', 'file_deleter'), show_progress=False) shared.gradio['delete_confirm'].click(handle_delete_confirm_click, gradio('delete_root', 'delete_filename'), gradio('file_deleter'), show_progress=False)
shared.gradio['save_character_confirm'].click(handle_save_character_confirm_click, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), gradio('character_menu', 'character_saver'), show_progress=False) shared.gradio['save_character_confirm'].click(handle_save_character_confirm_click, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), gradio('character_menu', 'character_saver'), show_progress=False)
shared.gradio['delete_character_confirm'].click(handle_delete_character_confirm_click, gradio('character_menu'), gradio('character_menu', 'character_deleter'), show_progress=False) shared.gradio['delete_character_confirm'].click(handle_delete_character_confirm_click, gradio('character_menu'), gradio('character_menu', 'character_deleter'), show_progress=False)
@ -88,17 +69,10 @@ def create_event_handlers():
shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'), show_progress=False) shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'), show_progress=False)
shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'), show_progress=False) shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'), show_progress=False)
# User save/delete event handlers
shared.gradio['save_user_confirm'].click(handle_save_user_confirm_click, gradio('name1', 'user_bio', 'your_picture', 'save_user_filename'), gradio('user_menu', 'user_saver'), show_progress=False)
shared.gradio['delete_user_confirm'].click(handle_delete_user_confirm_click, gradio('user_menu'), gradio('user_menu', 'user_deleter'), show_progress=False)
shared.gradio['save_user_cancel'].click(lambda: gr.update(visible=False), None, gradio('user_saver'), show_progress=False)
shared.gradio['delete_user_cancel'].click(lambda: gr.update(visible=False), None, gradio('user_deleter'), show_progress=False)
def handle_save_preset_confirm_click(filename, contents): def handle_save_preset_confirm_click(filename, contents):
try: try:
filename = sanitize_filename(filename) utils.save_file(f"user_data/presets/{filename}.yaml", contents)
utils.save_file(str(shared.user_data_dir / "presets" / f"{filename}.yaml"), contents)
available_presets = utils.get_available_presets() available_presets = utils.get_available_presets()
output = gr.update(choices=available_presets, value=filename) output = gr.update(choices=available_presets, value=filename)
except Exception: except Exception:
@ -111,30 +85,22 @@ def handle_save_preset_confirm_click(filename, contents):
] ]
def handle_save_confirm_click(root_state, filename, contents): def handle_save_confirm_click(root, filename, contents):
try: try:
if root_state is None: utils.save_file(root + filename, contents)
return None, gr.update(visible=False)
filename = sanitize_filename(filename)
utils.save_file(root_state + filename, contents)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
return None, gr.update(visible=False) return gr.update(visible=False)
def handle_delete_confirm_click(root_state, filename): def handle_delete_confirm_click(root, filename):
try: try:
if root_state is None: utils.delete_file(root + filename)
return None, gr.update(visible=False)
filename = sanitize_filename(filename)
utils.delete_file(root_state + filename)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
return None, gr.update(visible=False) return gr.update(visible=False)
def handle_save_character_confirm_click(name2, greeting, context, character_picture, filename): def handle_save_character_confirm_click(name2, greeting, context, character_picture, filename):
@ -177,61 +143,25 @@ def handle_save_preset_click(state):
def handle_delete_preset_click(preset): def handle_delete_preset_click(preset):
root = str(shared.user_data_dir / "presets") + "/"
return [ return [
f"{preset}.yaml", f"{preset}.yaml",
root, "user_data/presets/",
root,
gr.update(visible=True) gr.update(visible=True)
] ]
def handle_save_grammar_click(grammar_string): def handle_save_grammar_click(grammar_string):
root = str(shared.user_data_dir / "grammars") + "/"
return [ return [
grammar_string, grammar_string,
"My Fancy Grammar.gbnf", "My Fancy Grammar.gbnf",
root, "user_data/grammars/",
root,
gr.update(visible=True) gr.update(visible=True)
] ]
def handle_delete_grammar_click(grammar_file): def handle_delete_grammar_click(grammar_file):
root = str(shared.user_data_dir / "grammars") + "/"
return [ return [
grammar_file, grammar_file,
root, "user_data/grammars/",
root,
gr.update(visible=True) gr.update(visible=True)
] ]
def handle_save_user_confirm_click(name1, user_bio, your_picture, filename):
try:
chat.save_user(name1, user_bio, your_picture, filename)
available_users = utils.get_available_users()
output = gr.update(choices=available_users, value=filename)
except Exception:
output = gr.update()
traceback.print_exc()
return [
output,
gr.update(visible=False)
]
def handle_delete_user_confirm_click(user):
try:
index = str(utils.get_available_users().index(user))
chat.delete_user(user)
output = chat.update_user_menu_after_deletion(index)
except Exception:
output = gr.update()
traceback.print_exc()
return [
output,
gr.update(visible=False)
]

View file

@ -138,7 +138,7 @@ def save_generated_images(images, state, actual_seed):
return [] return []
date_str = datetime.now().strftime("%Y-%m-%d") date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = str(shared.user_data_dir / "image_outputs" / date_str) folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True) os.makedirs(folder_path, exist_ok=True)
metadata = build_generation_metadata(state, actual_seed) metadata = build_generation_metadata(state, actual_seed)
@ -214,7 +214,7 @@ def get_all_history_images(force_refresh=False):
"""Get all history images sorted by modification time (newest first). Uses caching.""" """Get all history images sorted by modification time (newest first). Uses caching."""
global _image_cache, _cache_timestamp global _image_cache, _cache_timestamp
output_dir = str(shared.user_data_dir / "image_outputs") output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
return [] return []
@ -728,8 +728,6 @@ def generate_prompt_variation(state):
variation = variation.rsplit("</think>", 1)[1] variation = variation.rsplit("</think>", 1)[1]
elif "<|start|>assistant<|channel|>final<|message|>" in variation: elif "<|start|>assistant<|channel|>final<|message|>" in variation:
variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1] variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
elif "<|channel|>final<|message|>" in variation:
variation = variation.rsplit("<|channel|>final<|message|>", 1)[1]
elif "</seed:think>" in variation: elif "</seed:think>" in variation:
variation = variation.rsplit("</seed:think>", 1)[1] variation = variation.rsplit("</seed:think>", 1)[1]

View file

@ -41,12 +41,11 @@ def create_ui():
gr.Markdown("## Main options") gr.Markdown("## Main options")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=-1, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Number of layers to offload to the GPU. -1 = auto.') shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=0, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=1048576, step=1024, value=shared.args.ctx_size, info='Context length. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. Common values: 4096, 8192, 16384, 32768, 65536, 131072.') shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=256, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7') shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.') shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).') shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices.')
shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.') shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
with gr.Column(): with gr.Column():
@ -56,43 +55,32 @@ def create_ui():
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit) shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit) shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.') shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.')
shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable tensor parallelism (TP).') shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable tensor parallelism (TP).')
shared.gradio['tensorrt_llm_info'] = gr.Markdown( shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
'* TensorRT-LLM has to be installed manually: `pip install tensorrt_llm==1.1.0 --extra-index-url https://pypi.nvidia.com`.\n\n' shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
'* You can load either a pre-built TensorRT engine or a regular HF model. '
'HF models will be compiled to a TensorRT engine automatically on each load (this can take a while).'
)
# Multimodal # Multimodal
with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']: with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
with gr.Row(): with gr.Row():
shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info=f'Select a file that matches your model. Must be placed in {shared.user_data_dir}/mmproj/', interactive=not mu) shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu) ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
# Speculative decoding # Speculative decoding
with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']: with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Maximum number of tokens to draft for speculative decoding. Recommended: 4 for draft model, 64 for n-gram.')
gr.Markdown('#### Draft model')
with gr.Row(): with gr.Row():
shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=['None'] + utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', info='Draft model. Must share the same vocabulary as the main model.', interactive=not mu) shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=['None'] + utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', info='Draft model. Speculative decoding only works with models sharing the same vocabulary (e.g., same model family).', interactive=not mu)
ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': ['None'] + utils.get_available_models()}, 'refresh-button', interactive=not mu) ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': ['None'] + utils.get_available_models()}, 'refresh-button', interactive=not mu)
shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.') shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.')
shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding. Recommended value: 4.')
shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1') shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.') shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
shared.gradio['ngram_header'] = gr.Markdown('#### N-gram (draftless)')
shared.gradio['spec_type'] = gr.Dropdown(label="spec-type", choices=['none', 'ngram-mod', 'ngram-simple', 'ngram-map-k', 'ngram-map-k4v', 'ngram-cache'], value=shared.args.spec_type, info='Draftless speculative decoding type. Recommended: ngram-mod.')
shared.gradio['spec_ngram_size_n'] = gr.Number(label="spec-ngram-size-n", precision=0, step=1, value=shared.args.spec_ngram_size_n, info='N-gram lookup size for speculative decoding.', visible=shared.args.spec_type != 'none')
shared.gradio['spec_ngram_size_m'] = gr.Number(label="spec-ngram-size-m", precision=0, step=1, value=shared.args.spec_ngram_size_m, info='Draft n-gram size for speculative decoding.', visible=shared.args.spec_type != 'none')
shared.gradio['spec_ngram_min_hits'] = gr.Number(label="spec-ngram-min-hits", precision=0, step=1, value=shared.args.spec_ngram_min_hits, info='Minimum n-gram hits for ngram-map speculative decoding.', visible=shared.args.spec_type != 'none')
gr.Markdown("## Other options") gr.Markdown("## Other options")
with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'): with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
shared.gradio['parallel'] = gr.Slider(label="parallel", minimum=1, step=1, maximum=64, value=shared.args.parallel, info='Number of parallel request slots for the API. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads) shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch) shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size) shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
@ -100,8 +88,12 @@ def create_ui():
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40') shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags) shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory) shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.') shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.') shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
with gr.Column(): with gr.Column():
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.') shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.')
@ -112,6 +104,9 @@ def create_ui():
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock) shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.') shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16) shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn)
shared.gradio['no_xformers'] = gr.Checkbox(label="no_xformers", value=shared.args.no_xformers)
shared.gradio['no_sdpa'] = gr.Checkbox(label="no_sdpa", value=shared.args.no_sdpa)
shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.') shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.') shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
if not shared.args.portable: if not shared.args.portable:
@ -162,35 +157,28 @@ def create_event_handlers():
handle_load_model_event_final, gradio('truncation_length', 'loader', 'interface_state'), gradio('truncation_length', 'filter_by_loader'), show_progress=False) handle_load_model_event_final, gradio('truncation_length', 'loader', 'interface_state'), gradio('truncation_length', 'filter_by_loader'), show_progress=False)
shared.gradio['unload_model'].click(handle_unload_model_click, None, gradio('model_status'), show_progress=False).then( shared.gradio['unload_model'].click(handle_unload_model_click, None, gradio('model_status'), show_progress=False).then(
update_gpu_layers_and_vram, gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False) partial(update_gpu_layers_and_vram, auto_adjust=True), gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info', 'gpu_layers'), show_progress=False)
shared.gradio['save_model_settings'].click( shared.gradio['save_model_settings'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False) save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False)
# For ctx_size and cache_type - update VRAM display # For ctx_size and cache_type - auto-adjust GPU layers
for param in ['ctx_size', 'cache_type']: for param in ['ctx_size', 'cache_type']:
shared.gradio[param].change( shared.gradio[param].change(
update_gpu_layers_and_vram, partial(update_gpu_layers_and_vram, auto_adjust=True),
gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'),
gradio('vram_info'), show_progress=False) gradio('vram_info', 'gpu_layers'), show_progress=False)
# For manual gpu_layers changes - only update VRAM # For manual gpu_layers changes - only update VRAM
shared.gradio['gpu_layers'].change( shared.gradio['gpu_layers'].change(
update_gpu_layers_and_vram, partial(update_gpu_layers_and_vram, auto_adjust=False),
gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'),
gradio('vram_info'), show_progress=False) gradio('vram_info'), show_progress=False)
if not shared.args.portable: if not shared.args.portable:
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False) shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
shared.gradio['spec_type'].change(
lambda x: [gr.update(visible=x != 'none')] * 3,
gradio('spec_type'),
gradio('spec_ngram_size_n', 'spec_ngram_size_m', 'spec_ngram_min_hits'),
show_progress=False
)
shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True) shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True) shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True) shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
@ -221,7 +209,7 @@ def load_model_wrapper(selected_model, loader, autoload=False):
yield f"Successfully loaded `{selected_model}`." yield f"Successfully loaded `{selected_model}`."
else: else:
yield f"Failed to load `{selected_model}`." yield f"Failed to load `{selected_model}`."
except Exception: except:
exc = traceback.format_exc() exc = traceback.format_exc()
logger.error('Failed to load the model.') logger.error('Failed to load the model.')
print(exc) print(exc)
@ -315,9 +303,9 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None
) )
if output_folder == shared.user_data_dir / "models": if output_folder == Path("user_data/models"):
output_folder = Path(shared.args.model_dir) output_folder = Path(shared.args.model_dir)
elif output_folder == shared.user_data_dir / "loras": elif output_folder == Path("user_data/loras"):
output_folder = Path(shared.args.lora_dir) output_folder = Path(shared.args.lora_dir)
if check: if check:
@ -385,12 +373,8 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
def update_truncation_length(current_length, state): def update_truncation_length(current_length, state):
if 'loader' in state: if 'loader' in state:
if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp': if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
if state['ctx_size'] > 0:
return state['ctx_size'] return state['ctx_size']
# ctx_size == 0 means auto: use the actual value from the server
return shared.settings['truncation_length']
return current_length return current_length
@ -402,6 +386,8 @@ def get_initial_vram_info():
shared.args.gpu_layers, shared.args.gpu_layers,
shared.args.ctx_size, shared.args.ctx_size,
shared.args.cache_type, shared.args.cache_type,
auto_adjust=False,
for_ui=True
) )
return "<div id=\"vram-info\"'>Estimated VRAM to load the model:</div>" return "<div id=\"vram-info\"'>Estimated VRAM to load the model:</div>"
@ -410,7 +396,7 @@ def get_initial_vram_info():
def get_initial_gpu_layers_max(): def get_initial_gpu_layers_max():
if shared.model_name != 'None' and shared.args.loader == 'llama.cpp': if shared.model_name != 'None' and shared.args.loader == 'llama.cpp':
model_settings = get_model_metadata(shared.model_name) model_settings = get_model_metadata(shared.model_name)
return model_settings.get('max_gpu_layers', 256) return model_settings.get('max_gpu_layers', model_settings.get('gpu_layers', 256))
return 256 return 256

View file

@ -194,7 +194,7 @@ def handle_new_prompt():
new_name = utils.current_time() new_name = utils.current_time()
# Create the new prompt file # Create the new prompt file
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.parent.mkdir(parents=True, exist_ok=True)
prompt_path.write_text("In this story,", encoding='utf-8') prompt_path.write_text("In this story,", encoding='utf-8')
@ -205,15 +205,15 @@ def handle_delete_prompt_confirm_notebook(prompt_name):
available_prompts = utils.get_available_prompts() available_prompts = utils.get_available_prompts()
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0 current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
(shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True) (Path("user_data/logs/notebook") / f"{prompt_name}.txt").unlink(missing_ok=True)
available_prompts = utils.get_available_prompts() available_prompts = utils.get_available_prompts()
if available_prompts: if available_prompts:
new_value = available_prompts[min(current_index, len(available_prompts) - 1)] new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
else: else:
new_value = utils.current_time() new_value = utils.current_time()
(shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True) Path("user_data/logs/notebook").mkdir(parents=True, exist_ok=True)
(shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,") (Path("user_data/logs/notebook") / f"{new_value}.txt").write_text("In this story,")
available_prompts = [new_value] available_prompts = [new_value]
return [ return [
@ -233,8 +233,8 @@ def handle_rename_prompt_click_notebook(current_name):
def handle_rename_prompt_confirm_notebook(new_name, current_name): def handle_rename_prompt_confirm_notebook(new_name, current_name):
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt" old_path = Path("user_data/logs/notebook") / f"{current_name}.txt"
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" new_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
if old_path.exists() and not new_path.exists(): if old_path.exists() and not new_path.exists():
old_path.rename(new_path) old_path.rename(new_path)
@ -250,7 +250,7 @@ def handle_rename_prompt_confirm_notebook(new_name, current_name):
def autosave_prompt(text, prompt_name): def autosave_prompt(text, prompt_name):
"""Automatically save the text to the selected prompt file""" """Automatically save the text to the selected prompt file"""
if prompt_name and text.strip(): if prompt_name and text.strip():
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt" prompt_path = Path("user_data/logs/notebook") / f"{prompt_name}.txt"
prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.parent.mkdir(parents=True, exist_ok=True)
prompt_path.write_text(text, encoding='utf-8') prompt_path.write_text(text, encoding='utf-8')

Some files were not shown because too many files have changed in this diff Show more