mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-20 04:14:38 +01:00
Compare commits
264 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
256431f258 | ||
|
|
88a318894c | ||
|
|
44810751de | ||
|
|
6c05a964a7 | ||
|
|
737ded6959 | ||
|
|
50685c93f2 | ||
|
|
9d9f5d9860 | ||
|
|
5cfe9fe295 | ||
|
|
b76a289e04 | ||
|
|
c0de1d176c | ||
|
|
4f80b20859 | ||
|
|
f8ff7cf99e | ||
|
|
92d376e420 | ||
|
|
f6a749a151 | ||
|
|
1a2b840938 | ||
|
|
bfea49b197 | ||
|
|
80d0c03bab | ||
|
|
9119ce0680 | ||
|
|
5763cab3c4 | ||
|
|
f0c16813ef | ||
|
|
2d3a3794c9 | ||
|
|
9955e54a1f | ||
|
|
d1aba08561 | ||
|
|
c126530061 | ||
|
|
b9bdbd638e | ||
|
|
9eacd4a207 | ||
|
|
e11425d5f8 | ||
|
|
4ae2bd86e2 | ||
|
|
9f657d3976 | ||
|
|
c09a367c64 | ||
|
|
beab346f48 | ||
|
|
573617157a | ||
|
|
d0a4993cf4 | ||
|
|
c7953fb923 | ||
|
|
c908ac00d7 | ||
|
|
8bff331893 | ||
|
|
cb08ba63dc | ||
|
|
09a6549816 | ||
|
|
accb2ef661 | ||
|
|
998b9bfb2a | ||
|
|
5f1707af35 | ||
|
|
16636c04b8 | ||
|
|
e8d1c66303 | ||
|
|
cb88066d15 | ||
|
|
0cd245bcbb | ||
|
|
24e7e77b55 | ||
|
|
cabb95f0d6 | ||
|
|
5362bbb413 | ||
|
|
d4c22ced83 | ||
|
|
aab2596d29 | ||
|
|
e0a38da9f3 | ||
|
|
e50b823eee | ||
|
|
b7670cc762 | ||
|
|
d0b72c73c0 | ||
|
|
c39c187f47 | ||
|
|
4628825651 | ||
|
|
fef95b9e56 | ||
|
|
5833d94d7f | ||
|
|
a4bef860b6 | ||
|
|
5ddc1002d2 | ||
|
|
c094bc943c | ||
|
|
85ec85e569 | ||
|
|
04213dff14 | ||
|
|
24fdcc52b3 | ||
|
|
58f26a4cc7 | ||
|
|
0e35421593 | ||
|
|
1ed56aee85 | ||
|
|
286ae475f6 | ||
|
|
4c7a56c18d | ||
|
|
a09f21b9de | ||
|
|
1b7e6c5705 | ||
|
|
f8936ec47c | ||
|
|
5c02b7f603 | ||
|
|
09d5e049d6 | ||
|
|
fdd8e5b1fd | ||
|
|
4f82b71ef3 | ||
|
|
bbd43d9463 | ||
|
|
3e6bd1a310 | ||
|
|
9a7428b627 | ||
|
|
2d0cc7726e | ||
|
|
d45c9b3c59 | ||
|
|
2466305f76 | ||
|
|
a916fb0e5c | ||
|
|
fb1b3b6ddf | ||
|
|
5a017aa338 | ||
|
|
4b6c9db1c9 | ||
|
|
09723c9988 | ||
|
|
2549f7c33b | ||
|
|
b5cac2e3b2 | ||
|
|
0d62038710 | ||
|
|
cf9ad8eafe | ||
|
|
980a9d1657 | ||
|
|
bb00d96dc3 | ||
|
|
66c976e995 | ||
|
|
24977846fb | ||
|
|
7a63a56043 | ||
|
|
f1cfeae372 | ||
|
|
3304b57bdf | ||
|
|
8aeaa76365 | ||
|
|
6ec4ca8b10 | ||
|
|
307c085d1b | ||
|
|
c604ca66de | ||
|
|
15792c3cb8 | ||
|
|
3b71932658 | ||
|
|
83b7e47d77 | ||
|
|
7f485274eb | ||
|
|
39e6c997cc | ||
|
|
970055ca00 | ||
|
|
d6643bb4bc | ||
|
|
9753b2342b | ||
|
|
eb4a20137a | ||
|
|
634609acca | ||
|
|
40f1837b42 | ||
|
|
f6ffecfff2 | ||
|
|
5a91b8462f | ||
|
|
7a8ca9f2b0 | ||
|
|
7170a16b91 | ||
|
|
b3705d87bf | ||
|
|
0132966d09 | ||
|
|
baf4e13ff1 | ||
|
|
6ff111d18e | ||
|
|
aeeff41cc0 | ||
|
|
0cecc0a041 | ||
|
|
e1bf0b866f | ||
|
|
3b7cf44406 | ||
|
|
b686193fe2 | ||
|
|
328215b0c7 | ||
|
|
304510eb3d | ||
|
|
085c4ef5d7 | ||
|
|
aa634c77c0 | ||
|
|
abc699db9b | ||
|
|
f2fe001cc4 | ||
|
|
7ea5513263 | ||
|
|
5fa709a3f4 | ||
|
|
e8e0d02406 | ||
|
|
1eead661c3 | ||
|
|
d48b53422f | ||
|
|
2beaa4b971 | ||
|
|
5f6754c267 | ||
|
|
b8b4471ab5 | ||
|
|
d03923924a | ||
|
|
044566d42d | ||
|
|
f5acf55207 | ||
|
|
3531069824 | ||
|
|
160f7ad6b4 | ||
|
|
8e24a20873 | ||
|
|
3bab7fbfd4 | ||
|
|
e7e0df0101 | ||
|
|
3323dedd08 | ||
|
|
36dbc4ccce | ||
|
|
86d59b4404 | ||
|
|
0e0e3ceb97 | ||
|
|
6d7018069c | ||
|
|
f9ed8820de | ||
|
|
3880c1a406 | ||
|
|
93ebfa2b7e | ||
|
|
d0ac58ad31 | ||
|
|
f06583b2b9 | ||
|
|
8be444a559 | ||
|
|
1729fb07b9 | ||
|
|
eba262d47a | ||
|
|
521ddbb722 | ||
|
|
66fb79fe15 | ||
|
|
e81a47f708 | ||
|
|
27bcc45c18 | ||
|
|
8a9afcbec6 | ||
|
|
2e7e966ef2 | ||
|
|
ddcad3cc51 | ||
|
|
8d43123f73 | ||
|
|
e2548f69a9 | ||
|
|
4c406e024f | ||
|
|
249bd6eea2 | ||
|
|
f52d9336e5 | ||
|
|
9824c82cb6 | ||
|
|
2f08dce7b0 | ||
|
|
134ac8fc29 | ||
|
|
409db3df1e | ||
|
|
86d8291e58 | ||
|
|
33ff3773a0 | ||
|
|
7a1fa8c9ea | ||
|
|
275810c843 | ||
|
|
438e59498e | ||
|
|
63f28cb4a2 | ||
|
|
33a38d7ece | ||
|
|
c2e494963f | ||
|
|
5b18be8582 | ||
|
|
d337ba0390 | ||
|
|
5be68cc073 | ||
|
|
1ffe540c97 | ||
|
|
1c2548fd89 | ||
|
|
da2d4f1a6a | ||
|
|
d278bb46a2 | ||
|
|
b16a1a874a | ||
|
|
45188eccef | ||
|
|
268cc3f100 | ||
|
|
69fa4dd0b1 | ||
|
|
fbfcd59fe0 | ||
|
|
d45aa6606a | ||
|
|
0804296f4d | ||
|
|
6a08e79fa5 | ||
|
|
ff48956cb0 | ||
|
|
5a22970ba8 | ||
|
|
387cf9d8df | ||
|
|
942ff8fcb4 | ||
|
|
da3010c3ed | ||
|
|
83cc207ef7 | ||
|
|
2ac4eb33c8 | ||
|
|
7bf15ad933 | ||
|
|
1d1f4dfc88 | ||
|
|
abb7cc02e9 | ||
|
|
68109bc5da | ||
|
|
952e2c404a | ||
|
|
cdf0e392e6 | ||
|
|
eb90daf098 | ||
|
|
0ffb75de7c | ||
|
|
d8af0505a8 | ||
|
|
9b916f02cd | ||
|
|
5d93f4e800 | ||
|
|
64eb77e782 | ||
|
|
22141679e3 | ||
|
|
65de4c30c8 | ||
|
|
f010aa1612 | ||
|
|
f4d787ab8d | ||
|
|
8a3d866401 | ||
|
|
11dc6fdfce | ||
|
|
7d42b6900e | ||
|
|
8cbb7661a8 | ||
|
|
866c48e55b | ||
|
|
b3fd0d16e0 | ||
|
|
d584ede72e | ||
|
|
c0bff831e3 | ||
|
|
2260e530c9 | ||
|
|
e9f22813e4 | ||
|
|
3519890c8e | ||
|
|
9c604628a0 | ||
|
|
fbd2acfa19 | ||
|
|
5fd79b23d1 | ||
|
|
b8fcc8ea32 | ||
|
|
d7dd533b99 | ||
|
|
9576c5a5f4 | ||
|
|
9814d3d0ae | ||
|
|
38d0eeefc0 | ||
|
|
ddd74324fe | ||
|
|
efc72d5c32 | ||
|
|
aecbc5a8ac | ||
|
|
c54e8a2b3d | ||
|
|
dc2bbf1861 | ||
|
|
cae1fef42d | ||
|
|
7493fe7841 | ||
|
|
21b979c02a | ||
|
|
a731861127 | ||
|
|
910456ba31 | ||
|
|
d79cdc614c | ||
|
|
332fd40653 | ||
|
|
50a35b483c | ||
|
|
45fbec0320 | ||
|
|
b0968ed8b4 | ||
|
|
36747cf99c | ||
|
|
2fcbadec67 | ||
|
|
bb3b7bc197 | ||
|
|
6e2c4e9c23 | ||
|
|
a2ed640aa6 | ||
|
|
1066fe8c21 | ||
|
|
9530d3a6d8 |
2
.github/workflows/build-everything-tgw.yml
vendored
2
.github/workflows/build-everything-tgw.yml
vendored
|
|
@ -67,4 +67,4 @@ jobs:
|
||||||
uses: ./.github/workflows/build-portable-release.yml
|
uses: ./.github/workflows/build-portable-release.yml
|
||||||
with:
|
with:
|
||||||
version: ${{ inputs.version }}
|
version: ${{ inputs.version }}
|
||||||
config: 'os:macos-13,macos-14'
|
config: 'os:macos-15-intel,macos-14'
|
||||||
|
|
|
||||||
|
|
@ -58,9 +58,8 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
$matrix = @{
|
$matrix = @{
|
||||||
'os' = @('ubuntu-22.04', 'windows-2022')
|
'os' = @('ubuntu-22.04', 'windows-2022')
|
||||||
'pyver' = @("3.11")
|
'pyver' = @("3.13")
|
||||||
'avx' = @("AVX2")
|
'cuda' = @("12.4", "13.1")
|
||||||
'cuda' = @("12.4")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
||||||
|
|
@ -75,7 +74,7 @@ jobs:
|
||||||
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
||||||
|
|
||||||
build_wheels:
|
build_wheels:
|
||||||
name: ${{ matrix.os }} ${{ matrix.pyver }} CPU ${{ matrix.avx }} CUDA ${{ matrix.cuda }}
|
name: ${{ matrix.os }} ${{ matrix.pyver }} CUDA ${{ matrix.cuda }}
|
||||||
needs: define_matrix
|
needs: define_matrix
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
|
|
@ -84,17 +83,16 @@ jobs:
|
||||||
run:
|
run:
|
||||||
shell: pwsh
|
shell: pwsh
|
||||||
env:
|
env:
|
||||||
AVXVER: ${{ matrix.avx }}
|
|
||||||
PCKGVER: ${{ inputs.version }}
|
PCKGVER: ${{ inputs.version }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
repository: 'oobabooga/text-generation-webui'
|
repository: 'oobabooga/text-generation-webui'
|
||||||
ref: ${{ inputs.version }}
|
ref: ${{ inputs.version }}
|
||||||
submodules: 'recursive'
|
submodules: 'recursive'
|
||||||
|
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.pyver }}
|
python-version: ${{ matrix.pyver }}
|
||||||
|
|
||||||
|
|
@ -113,21 +111,20 @@ jobs:
|
||||||
|
|
||||||
# Define common variables
|
# Define common variables
|
||||||
CUDA_VERSION="${{ matrix.cuda }}"
|
CUDA_VERSION="${{ matrix.cuda }}"
|
||||||
AVX_SUPPORT="${{ matrix.avx }}"
|
|
||||||
VERSION="${{ inputs.version }}"
|
VERSION="${{ inputs.version }}"
|
||||||
|
|
||||||
# 1. Set platform-specific variables
|
# 1. Set platform-specific variables
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
PLATFORM="windows"
|
PLATFORM="windows"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-pc-windows-msvc-install_only.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/python.exe -m pip"
|
PIP_PATH="portable_env/python.exe -m pip"
|
||||||
PACKAGES_PATH="portable_env/Lib/site-packages"
|
PACKAGES_PATH="portable_env/Lib/site-packages"
|
||||||
rm start_linux.sh start_macos.sh
|
rm start_linux.sh start_macos.sh
|
||||||
else
|
else
|
||||||
PLATFORM="linux"
|
PLATFORM="linux"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
|
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
||||||
rm start_macos.sh start_windows.bat
|
rm start_macos.sh start_windows.bat
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
@ -138,16 +135,13 @@ jobs:
|
||||||
tar -xzf python-build.tar.gz
|
tar -xzf python-build.tar.gz
|
||||||
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
|
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
|
||||||
|
|
||||||
# 3. Prepare requirements file based on AVX and CUDA
|
# 3. Prepare requirements file based on CUDA version
|
||||||
if [[ "$AVX_SUPPORT" == "AVX2" ]]; then
|
|
||||||
BASE_REQ_FILE="requirements/portable/requirements.txt"
|
|
||||||
else
|
|
||||||
BASE_REQ_FILE="requirements/portable/requirements_noavx2.txt"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Create CUDA-specific requirements file if needed
|
|
||||||
cd "text-generation-webui-${VERSION_CLEAN}"
|
cd "text-generation-webui-${VERSION_CLEAN}"
|
||||||
REQ_FILE="$BASE_REQ_FILE"
|
if [[ "$CUDA_VERSION" == "13.1" ]]; then
|
||||||
|
REQ_FILE="requirements/portable/requirements_cuda131.txt"
|
||||||
|
else
|
||||||
|
REQ_FILE="requirements/portable/requirements.txt"
|
||||||
|
fi
|
||||||
|
|
||||||
# 4. Install packages
|
# 4. Install packages
|
||||||
echo "Installing Python packages from $REQ_FILE..."
|
echo "Installing Python packages from $REQ_FILE..."
|
||||||
|
|
@ -156,15 +150,16 @@ jobs:
|
||||||
# 5. Clean up
|
# 5. Clean up
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
||||||
|
|
||||||
# 6. Create ZIP file
|
# 6. Create archive
|
||||||
cd ..
|
cd ..
|
||||||
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.zip"
|
|
||||||
echo "Creating archive: $ZIP_NAME"
|
|
||||||
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
|
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.zip"
|
||||||
|
echo "Creating archive: $ARCHIVE_NAME"
|
||||||
|
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
||||||
else
|
else
|
||||||
zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.tar.gz"
|
||||||
|
echo "Creating archive: $ARCHIVE_NAME"
|
||||||
|
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
- name: Upload files to a GitHub release
|
||||||
|
|
@ -173,7 +168,7 @@ jobs:
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
file: ../textgen-portable-*.zip
|
file: ../textgen-portable-*
|
||||||
tag: ${{ inputs.version }}
|
tag: ${{ inputs.version }}
|
||||||
file_glob: true
|
file_glob: true
|
||||||
make_latest: false
|
make_latest: false
|
||||||
|
|
|
||||||
|
|
@ -57,9 +57,8 @@ jobs:
|
||||||
id: set-matrix
|
id: set-matrix
|
||||||
run: |
|
run: |
|
||||||
$matrix = @{
|
$matrix = @{
|
||||||
'os' = @('ubuntu-22.04')
|
'os' = @('ubuntu-22.04', 'windows-2022')
|
||||||
'pyver' = @("3.11")
|
'pyver' = @("3.13")
|
||||||
'avx' = @("AVX2")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
||||||
|
|
@ -74,7 +73,7 @@ jobs:
|
||||||
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
||||||
|
|
||||||
build_wheels:
|
build_wheels:
|
||||||
name: ${{ matrix.os }} ${{ matrix.pyver }} CPU ${{ matrix.avx }}
|
name: ${{ matrix.os }} ${{ matrix.pyver }}
|
||||||
needs: define_matrix
|
needs: define_matrix
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
|
|
@ -83,17 +82,16 @@ jobs:
|
||||||
run:
|
run:
|
||||||
shell: pwsh
|
shell: pwsh
|
||||||
env:
|
env:
|
||||||
AVXVER: ${{ matrix.avx }}
|
|
||||||
PCKGVER: ${{ inputs.version }}
|
PCKGVER: ${{ inputs.version }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
repository: 'oobabooga/text-generation-webui'
|
repository: 'oobabooga/text-generation-webui'
|
||||||
ref: ${{ inputs.version }}
|
ref: ${{ inputs.version }}
|
||||||
submodules: 'recursive'
|
submodules: 'recursive'
|
||||||
|
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.pyver }}
|
python-version: ${{ matrix.pyver }}
|
||||||
|
|
||||||
|
|
@ -111,15 +109,22 @@ jobs:
|
||||||
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
||||||
|
|
||||||
# Define common variables
|
# Define common variables
|
||||||
AVX_SUPPORT="${{ matrix.avx }}"
|
|
||||||
VERSION="${{ inputs.version }}"
|
VERSION="${{ inputs.version }}"
|
||||||
|
|
||||||
# 1. Set platform-specific variables (Linux only for ROCm)
|
# 1. Set platform-specific variables
|
||||||
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
|
PLATFORM="windows"
|
||||||
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz"
|
||||||
|
PIP_PATH="portable_env/python.exe -m pip"
|
||||||
|
PACKAGES_PATH="portable_env/Lib/site-packages"
|
||||||
|
rm start_linux.sh start_macos.sh
|
||||||
|
else
|
||||||
PLATFORM="linux"
|
PLATFORM="linux"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
|
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
||||||
rm start_macos.sh start_windows.bat
|
rm start_macos.sh start_windows.bat
|
||||||
|
fi
|
||||||
|
|
||||||
# 2. Download and extract Python
|
# 2. Download and extract Python
|
||||||
cd ..
|
cd ..
|
||||||
|
|
@ -128,13 +133,8 @@ jobs:
|
||||||
tar -xzf python-build.tar.gz
|
tar -xzf python-build.tar.gz
|
||||||
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
|
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
|
||||||
|
|
||||||
# 3. Prepare requirements file based on AVX
|
# 3. Prepare requirements file
|
||||||
if [[ "$AVX_SUPPORT" == "AVX2" ]]; then
|
REQ_FILE="requirements/portable/requirements_amd.txt"
|
||||||
BASE_REQ_FILE="requirements/portable/requirements_amd.txt"
|
|
||||||
else
|
|
||||||
BASE_REQ_FILE="requirements/portable/requirements_amd_noavx2.txt"
|
|
||||||
fi
|
|
||||||
REQ_FILE="$BASE_REQ_FILE"
|
|
||||||
|
|
||||||
cd "text-generation-webui-${VERSION_CLEAN}"
|
cd "text-generation-webui-${VERSION_CLEAN}"
|
||||||
|
|
||||||
|
|
@ -145,12 +145,17 @@ jobs:
|
||||||
# 5. Clean up
|
# 5. Clean up
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
||||||
|
|
||||||
# 6. Create ZIP file
|
# 6. Create archive
|
||||||
cd ..
|
cd ..
|
||||||
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm.zip"
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
echo "Creating archive: $ZIP_NAME"
|
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.zip"
|
||||||
|
echo "Creating archive: $ARCHIVE_NAME"
|
||||||
zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
||||||
|
else
|
||||||
|
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.tar.gz"
|
||||||
|
echo "Creating archive: $ARCHIVE_NAME"
|
||||||
|
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
- name: Upload files to a GitHub release
|
||||||
id: upload-release
|
id: upload-release
|
||||||
|
|
@ -158,7 +163,7 @@ jobs:
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
file: ../textgen-portable-*.zip
|
file: ../textgen-portable-*
|
||||||
tag: ${{ inputs.version }}
|
tag: ${{ inputs.version }}
|
||||||
file_glob: true
|
file_glob: true
|
||||||
make_latest: false
|
make_latest: false
|
||||||
|
|
|
||||||
|
|
@ -58,8 +58,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
$matrix = @{
|
$matrix = @{
|
||||||
'os' = @('ubuntu-22.04', 'windows-2022')
|
'os' = @('ubuntu-22.04', 'windows-2022')
|
||||||
'pyver' = @("3.11")
|
'pyver' = @("3.13")
|
||||||
'avx' = @("AVX2")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
||||||
|
|
@ -74,7 +73,7 @@ jobs:
|
||||||
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
||||||
|
|
||||||
build_wheels:
|
build_wheels:
|
||||||
name: ${{ matrix.os }} ${{ matrix.pyver }} CPU ${{ matrix.avx }}
|
name: ${{ matrix.os }} ${{ matrix.pyver }}
|
||||||
needs: define_matrix
|
needs: define_matrix
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
|
|
@ -83,17 +82,16 @@ jobs:
|
||||||
run:
|
run:
|
||||||
shell: pwsh
|
shell: pwsh
|
||||||
env:
|
env:
|
||||||
AVXVER: ${{ matrix.avx }}
|
|
||||||
PCKGVER: ${{ inputs.version }}
|
PCKGVER: ${{ inputs.version }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
repository: 'oobabooga/text-generation-webui'
|
repository: 'oobabooga/text-generation-webui'
|
||||||
ref: ${{ inputs.version }}
|
ref: ${{ inputs.version }}
|
||||||
submodules: 'recursive'
|
submodules: 'recursive'
|
||||||
|
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.pyver }}
|
python-version: ${{ matrix.pyver }}
|
||||||
|
|
||||||
|
|
@ -111,21 +109,20 @@ jobs:
|
||||||
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
||||||
|
|
||||||
# Define common variables
|
# Define common variables
|
||||||
AVX_SUPPORT="${{ matrix.avx }}"
|
|
||||||
VERSION="${{ inputs.version }}"
|
VERSION="${{ inputs.version }}"
|
||||||
|
|
||||||
# 1. Set platform-specific variables
|
# 1. Set platform-specific variables
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
PLATFORM="windows"
|
PLATFORM="windows"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-pc-windows-msvc-install_only.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/python.exe -m pip"
|
PIP_PATH="portable_env/python.exe -m pip"
|
||||||
PACKAGES_PATH="portable_env/Lib/site-packages"
|
PACKAGES_PATH="portable_env/Lib/site-packages"
|
||||||
rm start_linux.sh start_macos.sh
|
rm start_linux.sh start_macos.sh
|
||||||
else
|
else
|
||||||
PLATFORM="linux"
|
PLATFORM="linux"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
|
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
||||||
rm start_macos.sh start_windows.bat
|
rm start_macos.sh start_windows.bat
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
@ -136,13 +133,8 @@ jobs:
|
||||||
tar -xzf python-build.tar.gz
|
tar -xzf python-build.tar.gz
|
||||||
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
|
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
|
||||||
|
|
||||||
# 3. Prepare requirements file based on AVX
|
# 3. Prepare requirements file
|
||||||
if [[ "$AVX_SUPPORT" == "AVX2" ]]; then
|
REQ_FILE="requirements/portable/requirements_vulkan.txt"
|
||||||
BASE_REQ_FILE="requirements/portable/requirements_vulkan.txt"
|
|
||||||
else
|
|
||||||
BASE_REQ_FILE="requirements/portable/requirements_vulkan_noavx2.txt"
|
|
||||||
fi
|
|
||||||
REQ_FILE="$BASE_REQ_FILE"
|
|
||||||
|
|
||||||
cd "text-generation-webui-${VERSION_CLEAN}"
|
cd "text-generation-webui-${VERSION_CLEAN}"
|
||||||
|
|
||||||
|
|
@ -153,15 +145,16 @@ jobs:
|
||||||
# 5. Clean up
|
# 5. Clean up
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
||||||
|
|
||||||
# 6. Create ZIP file
|
# 6. Create archive
|
||||||
cd ..
|
cd ..
|
||||||
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.zip"
|
|
||||||
echo "Creating archive: $ZIP_NAME"
|
|
||||||
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
|
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.zip"
|
||||||
|
echo "Creating archive: $ARCHIVE_NAME"
|
||||||
|
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
||||||
else
|
else
|
||||||
zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.tar.gz"
|
||||||
|
echo "Creating archive: $ARCHIVE_NAME"
|
||||||
|
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
- name: Upload files to a GitHub release
|
||||||
|
|
@ -170,7 +163,7 @@ jobs:
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
file: ../textgen-portable-*.zip
|
file: ../textgen-portable-*
|
||||||
tag: ${{ inputs.version }}
|
tag: ${{ inputs.version }}
|
||||||
file_glob: true
|
file_glob: true
|
||||||
make_latest: false
|
make_latest: false
|
||||||
|
|
|
||||||
49
.github/workflows/build-portable-release.yml
vendored
49
.github/workflows/build-portable-release.yml
vendored
|
|
@ -58,8 +58,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
$matrix = @{
|
$matrix = @{
|
||||||
'os' = @('ubuntu-22.04', 'windows-2022', 'macos-14')
|
'os' = @('ubuntu-22.04', 'windows-2022', 'macos-14')
|
||||||
'pyver' = @("3.11")
|
'pyver' = @("3.13")
|
||||||
'avx' = @("AVX2")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})}
|
||||||
|
|
@ -74,7 +73,7 @@ jobs:
|
||||||
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT
|
||||||
|
|
||||||
build_wheels:
|
build_wheels:
|
||||||
name: ${{ matrix.os }} ${{ matrix.pyver }} CPU ${{ matrix.avx }}
|
name: ${{ matrix.os }} ${{ matrix.pyver }}
|
||||||
needs: define_matrix
|
needs: define_matrix
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
|
|
@ -83,17 +82,16 @@ jobs:
|
||||||
run:
|
run:
|
||||||
shell: pwsh
|
shell: pwsh
|
||||||
env:
|
env:
|
||||||
AVXVER: ${{ matrix.avx }}
|
|
||||||
PCKGVER: ${{ inputs.version }}
|
PCKGVER: ${{ inputs.version }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
with:
|
with:
|
||||||
repository: 'oobabooga/text-generation-webui'
|
repository: 'oobabooga/text-generation-webui'
|
||||||
ref: ${{ inputs.version }}
|
ref: ${{ inputs.version }}
|
||||||
submodules: 'recursive'
|
submodules: 'recursive'
|
||||||
|
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.pyver }}
|
python-version: ${{ matrix.pyver }}
|
||||||
|
|
||||||
|
|
@ -111,36 +109,35 @@ jobs:
|
||||||
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf
|
||||||
|
|
||||||
# Define common variables
|
# Define common variables
|
||||||
AVX_SUPPORT="${{ matrix.avx }}"
|
|
||||||
VERSION="${{ inputs.version }}"
|
VERSION="${{ inputs.version }}"
|
||||||
OS_TYPE="${{ matrix.os }}"
|
OS_TYPE="${{ matrix.os }}"
|
||||||
|
|
||||||
# 1. Set platform-specific variables
|
# 1. Set platform-specific variables
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
PLATFORM="windows-cpu"
|
PLATFORM="windows-cpu"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-pc-windows-msvc-install_only.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/python.exe -m pip"
|
PIP_PATH="portable_env/python.exe -m pip"
|
||||||
PACKAGES_PATH="portable_env/Lib/site-packages"
|
PACKAGES_PATH="portable_env/Lib/site-packages"
|
||||||
rm start_linux.sh start_macos.sh
|
rm start_linux.sh start_macos.sh
|
||||||
elif [[ "$RUNNER_OS" == "macOS" ]]; then
|
elif [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||||
if [[ "$OS_TYPE" == "macos-13" ]]; then
|
if [[ "$OS_TYPE" == "macos-15-intel" ]]; then
|
||||||
PLATFORM="macos-x86_64"
|
PLATFORM="macos-x86_64"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-apple-darwin-install_only.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-apple-darwin-install_only.tar.gz"
|
||||||
REQ_TYPE="apple_intel"
|
REQ_TYPE="apple_intel"
|
||||||
else
|
else
|
||||||
PLATFORM="macos-arm64"
|
PLATFORM="macos-arm64"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-aarch64-apple-darwin-install_only.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-aarch64-apple-darwin-install_only.tar.gz"
|
||||||
REQ_TYPE="apple_silicon"
|
REQ_TYPE="apple_silicon"
|
||||||
fi
|
fi
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
|
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
||||||
rm start_linux.sh start_windows.bat
|
rm start_linux.sh start_windows.bat
|
||||||
else
|
else
|
||||||
# Linux case
|
# Linux case
|
||||||
PLATFORM="linux-cpu"
|
PLATFORM="linux-cpu"
|
||||||
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20250409/cpython-3.11.12+20250409-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz"
|
||||||
PIP_PATH="portable_env/bin/python -m pip"
|
PIP_PATH="portable_env/bin/python -m pip"
|
||||||
PACKAGES_PATH="portable_env/lib/python3.11/site-packages"
|
PACKAGES_PATH="portable_env/lib/python3.13/site-packages"
|
||||||
rm start_macos.sh start_windows.bat
|
rm start_macos.sh start_windows.bat
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
@ -151,23 +148,18 @@ jobs:
|
||||||
tar -xzf python-build.tar.gz
|
tar -xzf python-build.tar.gz
|
||||||
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
|
mv python "text-generation-webui-${VERSION_CLEAN}/portable_env"
|
||||||
|
|
||||||
# 3. Prepare requirements file based on platform and AVX
|
# 3. Prepare requirements file based on platform
|
||||||
cd "text-generation-webui-${VERSION_CLEAN}"
|
cd "text-generation-webui-${VERSION_CLEAN}"
|
||||||
|
|
||||||
# Select requirements file based on platform
|
# Select requirements file based on platform
|
||||||
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||||
if [[ "$OS_TYPE" == "macos-13" ]]; then
|
if [[ "$OS_TYPE" == "macos-15-intel" ]]; then
|
||||||
REQ_FILE="requirements/portable/requirements_apple_intel.txt"
|
REQ_FILE="requirements/portable/requirements_apple_intel.txt"
|
||||||
else
|
else
|
||||||
REQ_FILE="requirements/portable/requirements_apple_silicon.txt"
|
REQ_FILE="requirements/portable/requirements_apple_silicon.txt"
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
# For Windows and Linux, check AVX support
|
|
||||||
if [[ "$AVX_SUPPORT" == "AVX2" ]]; then
|
|
||||||
REQ_FILE="requirements/portable/requirements_cpu_only.txt"
|
REQ_FILE="requirements/portable/requirements_cpu_only.txt"
|
||||||
else
|
|
||||||
REQ_FILE="requirements/portable/requirements_cpu_only_noavx2.txt"
|
|
||||||
fi
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Using requirements file: $REQ_FILE"
|
echo "Using requirements file: $REQ_FILE"
|
||||||
|
|
@ -179,15 +171,16 @@ jobs:
|
||||||
# 5. Clean up
|
# 5. Clean up
|
||||||
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py
|
||||||
|
|
||||||
# 6. Create ZIP file
|
# 6. Create archive
|
||||||
cd ..
|
cd ..
|
||||||
ZIP_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.zip"
|
|
||||||
echo "Creating archive: $ZIP_NAME"
|
|
||||||
|
|
||||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ZIP_NAME"
|
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.zip"
|
||||||
|
echo "Creating archive: $ARCHIVE_NAME"
|
||||||
|
powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME"
|
||||||
else
|
else
|
||||||
zip -r "$ZIP_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.tar.gz"
|
||||||
|
echo "Creating archive: $ARCHIVE_NAME"
|
||||||
|
tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Upload files to a GitHub release
|
- name: Upload files to a GitHub release
|
||||||
|
|
@ -196,7 +189,7 @@ jobs:
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
with:
|
with:
|
||||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
file: ../textgen-portable-*.zip
|
file: ../textgen-portable-*
|
||||||
tag: ${{ inputs.version }}
|
tag: ${{ inputs.version }}
|
||||||
file_glob: true
|
file_glob: true
|
||||||
make_latest: false
|
make_latest: false
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"#@title 2. Launch the web UI\n",
|
"#@title 2. Launch the web UI\n",
|
||||||
"\n",
|
"\n",
|
||||||
"#@markdown If unsure about the branch, write \"main\" or leave it blank.\n",
|
"#@markdown You can provide a direct GGUF link or a Hugging Face model URL.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"from pathlib import Path\n",
|
"from pathlib import Path\n",
|
||||||
|
|
@ -72,9 +72,9 @@
|
||||||
" ./start_linux.sh\n",
|
" ./start_linux.sh\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Parameters\n",
|
"# Parameters\n",
|
||||||
"model_url = \"https://huggingface.co/turboderp/gemma-2-9b-it-exl2\" #@param {type:\"string\"}\n",
|
"model_url = \"https://huggingface.co/unsloth/Qwen3.5-9B-GGUF/resolve/main/Qwen3.5-9B-Q4_K_M.gguf\" #@param {type:\"string\"}\n",
|
||||||
"branch = \"8.0bpw\" #@param {type:\"string\"}\n",
|
"branch = \"\" #@param {type:\"string\"}\n",
|
||||||
"command_line_flags = \"--n-gpu-layers 128 --load-in-4bit --use_double_quant --no_flash_attn\" #@param {type:\"string\"}\n",
|
"command_line_flags = \"--load-in-4bit --use_double_quant\" #@param {type:\"string\"}\n",
|
||||||
"api = False #@param {type:\"boolean\"}\n",
|
"api = False #@param {type:\"boolean\"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"if api:\n",
|
"if api:\n",
|
||||||
|
|
@ -83,26 +83,28 @@
|
||||||
" command_line_flags += f\" {param}\"\n",
|
" command_line_flags += f\" {param}\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model_url = model_url.strip()\n",
|
"model_url = model_url.strip()\n",
|
||||||
|
"model_name = \"\"\n",
|
||||||
"if model_url != \"\":\n",
|
"if model_url != \"\":\n",
|
||||||
" if not model_url.startswith('http'):\n",
|
" if not model_url.startswith('http'):\n",
|
||||||
" model_url = 'https://huggingface.co/' + model_url\n",
|
" model_url = 'https://huggingface.co/' + model_url\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Download the model\n",
|
" branch = branch.strip()\n",
|
||||||
" url_parts = model_url.strip('/').strip().split('/')\n",
|
" if '/resolve/' in model_url:\n",
|
||||||
" output_folder = f\"{url_parts[-2]}_{url_parts[-1]}\"\n",
|
" model_name = model_url.split('?')[0].split('/')[-1]\n",
|
||||||
" branch = branch.strip('\"\\' ')\n",
|
" !python download-model.py {model_url}\n",
|
||||||
" if branch.strip() not in ['', 'main']:\n",
|
" else:\n",
|
||||||
" output_folder += f\"_{branch}\"\n",
|
" url_parts = model_url.strip('/').split('/')\n",
|
||||||
|
" model_name = f\"{url_parts[-2]}_{url_parts[-1]}\"\n",
|
||||||
|
" if branch not in ['', 'main']:\n",
|
||||||
|
" model_name += f\"_{branch}\"\n",
|
||||||
" !python download-model.py {model_url} --branch {branch}\n",
|
" !python download-model.py {model_url} --branch {branch}\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" !python download-model.py {model_url}\n",
|
" !python download-model.py {model_url}\n",
|
||||||
"else:\n",
|
|
||||||
" output_folder = \"\"\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Start the web UI\n",
|
"# Start the web UI\n",
|
||||||
"cmd = f\"./start_linux.sh {command_line_flags} --share\"\n",
|
"cmd = f\"./start_linux.sh {command_line_flags} --share\"\n",
|
||||||
"if output_folder != \"\":\n",
|
"if model_name != \"\":\n",
|
||||||
" cmd += f\" --model {output_folder}\"\n",
|
" cmd += f\" --model {model_name}\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"!$cmd"
|
"!$cmd"
|
||||||
],
|
],
|
||||||
|
|
|
||||||
257
README.md
257
README.md
|
|
@ -13,7 +13,7 @@
|
||||||
|
|
||||||
# Text Generation Web UI
|
# Text Generation Web UI
|
||||||
|
|
||||||
A Gradio web UI for Large Language Models.
|
A Gradio web UI for running Large Language Models locally. 100% private and offline. Supports text generation, vision, tool-calling, training, image generation, and more.
|
||||||
|
|
||||||
[Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason)
|
[Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason)
|
||||||
|
|
||||||
|
|
@ -21,29 +21,23 @@ A Gradio web UI for Large Language Models.
|
||||||
|:---:|:---:|
|
|:---:|:---:|
|
||||||
| |  |
|
| |  |
|
||||||
|
|
||||||
## 🔥 News
|
|
||||||
|
|
||||||
- The project now supports **image generation**! Including Z-Image-Turbo, 4bit/8bit quantization, `torch.compile`, and LLM-generated prompt variations ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- Supports multiple local text generation backends, including [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), [ExLlamaV2](https://github.com/turboderp-org/exllamav2), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) (the latter via its own [Dockerfile](https://github.com/oobabooga/text-generation-webui/blob/main/docker/TensorRT-LLM/Dockerfile)).
|
- **Multiple backends**: [llama.cpp](https://github.com/ggerganov/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Switch between backends and models without restarting.
|
||||||
- Easy setup: Choose between **portable builds** (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or the one-click installer that creates a self-contained `installer_files` directory.
|
|
||||||
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
|
|
||||||
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
|
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
|
||||||
- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
|
- **Vision (multimodal)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
|
||||||
|
- **Tool-calling**: Models can call custom functions during chat — web search, page fetching, math, and more. Each tool is a single `.py` file, easy to create and extend ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Tool-Calling-Tutorial)).
|
||||||
|
- **OpenAI-compatible API**: Chat and Completions endpoints with tool-calling support. Use as a local drop-in replacement for the OpenAI API ([examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples)).
|
||||||
|
- **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)).
|
||||||
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
|
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
|
||||||
- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
|
- **Easy setup**: [Portable builds](https://github.com/oobabooga/text-generation-webui/releases) (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or a one-click installer for the full feature set.
|
||||||
|
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
|
||||||
|
- `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters. Prompts are automatically formatted with Jinja2 templates.
|
||||||
|
- Edit messages, navigate between message versions, and branch conversations at any point.
|
||||||
|
- Free-form text generation in the Notebook tab without being limited to chat turns.
|
||||||
|
- Multiple sampling parameters and generation options for sophisticated text generation control.
|
||||||
- Aesthetic UI with dark and light themes.
|
- Aesthetic UI with dark and light themes.
|
||||||
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
|
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
|
||||||
- `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters.
|
|
||||||
- Automatic prompt formatting using Jinja2 templates. You don't need to ever worry about prompt formats.
|
|
||||||
- Edit messages, navigate between message versions, and branch conversations at any point.
|
|
||||||
- Multiple sampling parameters and generation options for sophisticated text generation control.
|
|
||||||
- Switch between different models in the UI without restarting.
|
|
||||||
- Automatic GPU layers for GGUF models (on NVIDIA GPUs).
|
|
||||||
- Free-form text generation in the Notebook tab without being limited to chat turns.
|
|
||||||
- OpenAI-compatible API with Chat and Completions endpoints, including tool-calling support – see [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples).
|
|
||||||
- Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.
|
- Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details.
|
||||||
|
|
||||||
## How to install
|
## How to install
|
||||||
|
|
@ -52,10 +46,11 @@ A Gradio web UI for Large Language Models.
|
||||||
|
|
||||||
No installation needed – just download, unzip and run. All dependencies included.
|
No installation needed – just download, unzip and run. All dependencies included.
|
||||||
|
|
||||||
Compatible with GGUF (llama.cpp) models on Windows, Linux, and macOS.
|
|
||||||
|
|
||||||
Download from here: **https://github.com/oobabooga/text-generation-webui/releases**
|
Download from here: **https://github.com/oobabooga/text-generation-webui/releases**
|
||||||
|
|
||||||
|
- Builds are provided for Linux, Windows, and macOS, with options for CUDA, Vulkan, ROCm, and CPU-only.
|
||||||
|
- Compatible with GGUF (llama.cpp) models.
|
||||||
|
|
||||||
#### Option 2: Manual portable install with venv
|
#### Option 2: Manual portable install with venv
|
||||||
|
|
||||||
Very fast setup that should work on any Python 3.9+:
|
Very fast setup that should work on any Python 3.9+:
|
||||||
|
|
@ -86,7 +81,7 @@ deactivate
|
||||||
|
|
||||||
#### Option 3: One-click installer
|
#### Option 3: One-click installer
|
||||||
|
|
||||||
For users who need additional backends (ExLlamaV3, Transformers) or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch.
|
For users who need additional backends (ExLlamaV3, Transformers), training, image generation, or extensions (TTS, voice input, translation, etc). Requires ~10GB disk space and downloads PyTorch.
|
||||||
|
|
||||||
1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it.
|
1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it.
|
||||||
2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`.
|
2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`.
|
||||||
|
|
@ -141,7 +136,7 @@ For other platforms, download from: https://github.com/conda-forge/miniforge/rel
|
||||||
#### 1. Create a new conda environment
|
#### 1. Create a new conda environment
|
||||||
|
|
||||||
```
|
```
|
||||||
conda create -n textgen python=3.11
|
conda create -n textgen python=3.13
|
||||||
conda activate textgen
|
conda activate textgen
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -149,12 +144,12 @@ conda activate textgen
|
||||||
|
|
||||||
| System | GPU | Command |
|
| System | GPU | Command |
|
||||||
|--------|---------|---------|
|
|--------|---------|---------|
|
||||||
| Linux/WSL | NVIDIA | `pip3 install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128` |
|
| Linux/WSL | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
|
||||||
| Linux/WSL | CPU only | `pip3 install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu` |
|
| Linux/WSL | CPU only | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu` |
|
||||||
| Linux | AMD | `pip3 install torch==2.7.1 --index-url https://download.pytorch.org/whl/rocm6.2.4` |
|
| Linux | AMD | `pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp313-cp313-linux_x86_64.whl` |
|
||||||
| MacOS + MPS | Any | `pip3 install torch==2.7.1` |
|
| MacOS + MPS | Any | `pip3 install torch==2.9.1` |
|
||||||
| Windows | NVIDIA | `pip3 install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu128` |
|
| Windows | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` |
|
||||||
| Windows | CPU only | `pip3 install torch==2.7.1` |
|
| Windows | CPU only | `pip3 install torch==2.9.1` |
|
||||||
|
|
||||||
The up-to-date commands can be found here: https://pytorch.org/get-started/locally/.
|
The up-to-date commands can be found here: https://pytorch.org/get-started/locally/.
|
||||||
|
|
||||||
|
|
@ -174,16 +169,13 @@ pip install -r requirements/full/<requirements file according to table below>
|
||||||
|
|
||||||
Requirements file to use:
|
Requirements file to use:
|
||||||
|
|
||||||
| GPU | CPU | requirements file to use |
|
| GPU | requirements file to use |
|
||||||
|--------|---------|---------|
|
|--------|---------|
|
||||||
| NVIDIA | has AVX2 | `requirements.txt` |
|
| NVIDIA | `requirements.txt` |
|
||||||
| NVIDIA | no AVX2 | `requirements_noavx2.txt` |
|
| AMD | `requirements_amd.txt` |
|
||||||
| AMD | has AVX2 | `requirements_amd.txt` |
|
| CPU only | `requirements_cpu_only.txt` |
|
||||||
| AMD | no AVX2 | `requirements_amd_noavx2.txt` |
|
| Apple Intel | `requirements_apple_intel.txt` |
|
||||||
| CPU only | has AVX2 | `requirements_cpu_only.txt` |
|
| Apple Silicon | `requirements_apple_silicon.txt` |
|
||||||
| CPU only | no AVX2 | `requirements_cpu_only_noavx2.txt` |
|
|
||||||
| Apple | Intel | `requirements_apple_intel.txt` |
|
|
||||||
| Apple | Apple Silicon | `requirements_apple_silicon.txt` |
|
|
||||||
|
|
||||||
### Start the web UI
|
### Start the web UI
|
||||||
|
|
||||||
|
|
@ -209,7 +201,7 @@ ln -s docker/{nvidia/Dockerfile,nvidia/docker-compose.yml,.dockerignore} .
|
||||||
For AMD GPU:
|
For AMD GPU:
|
||||||
ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} .
|
ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} .
|
||||||
For Intel GPU:
|
For Intel GPU:
|
||||||
ln -s docker/{intel/Dockerfile,amd/docker-compose.yml,.dockerignore} .
|
ln -s docker/{intel/Dockerfile,intel/docker-compose.yml,.dockerignore} .
|
||||||
For CPU only
|
For CPU only
|
||||||
ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} .
|
ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} .
|
||||||
cp docker/.env.example .env
|
cp docker/.env.example .env
|
||||||
|
|
@ -244,17 +236,24 @@ List of command-line flags
|
||||||
</summary>
|
</summary>
|
||||||
|
|
||||||
```txt
|
```txt
|
||||||
usage: server.py [-h] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS]
|
usage: server.py [-h] [--user-data-dir USER_DATA_DIR] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS]
|
||||||
[--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT]
|
[--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--image-model IMAGE_MODEL] [--image-model-dir IMAGE_MODEL_DIR] [--image-dtype {bfloat16,float16}]
|
||||||
[--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT] [--ctx-size-draft CTX_SIZE_DRAFT] [--gpu-layers N] [--mmproj MMPROJ] [--streaming-llm]
|
[--image-attn-backend {flash_attention_2,sdpa}] [--image-cpu-offload] [--image-compile] [--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}]
|
||||||
[--tensor-split TENSOR_SPLIT] [--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa]
|
[--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT] [--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT]
|
||||||
[--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16] [--no-cache] [--trust-remote-code]
|
[--ctx-size-draft CTX_SIZE_DRAFT] [--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}] [--spec-ngram-size-n SPEC_NGRAM_SIZE_N]
|
||||||
[--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE] [--quant_type QUANT_TYPE]
|
[--spec-ngram-size-m SPEC_NGRAM_SIZE_M] [--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS] [--gpu-layers N] [--cpu-moe] [--mmproj MMPROJ] [--streaming-llm] [--tensor-split TENSOR_SPLIT]
|
||||||
[--enable-tp] [--tp-backend TP_BACKEND] [--gpu-split GPU_SPLIT] [--autosplit] [--cfg-cache] [--no_flash_attn] [--no_xformers] [--no_sdpa] [--num_experts_per_token N] [--cpp-runner]
|
[--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--ubatch-size UBATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa]
|
||||||
[--deepspeed] [--nvme-offload-dir NVME_OFFLOAD_DIR] [--local_rank LOCAL_RANK] [--alpha_value ALPHA_VALUE] [--rope_freq_base ROPE_FREQ_BASE] [--compress_pos_emb COMPRESS_POS_EMB]
|
[--parallel PARALLEL] [--fit-target FIT_TARGET] [--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16]
|
||||||
[--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share] [--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH]
|
[--no-cache] [--trust-remote-code] [--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE]
|
||||||
[--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors] [--portable] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT]
|
[--quant_type QUANT_TYPE] [--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share]
|
||||||
[--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4] [--nowebui]
|
[--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors]
|
||||||
|
[--portable] [--api] [--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4]
|
||||||
|
[--nowebui] [--temperature N] [--dynatemp-low N] [--dynatemp-high N] [--dynatemp-exponent N] [--smoothing-factor N] [--smoothing-curve N] [--min-p N] [--top-p N] [--top-k N]
|
||||||
|
[--typical-p N] [--xtc-threshold N] [--xtc-probability N] [--epsilon-cutoff N] [--eta-cutoff N] [--tfs N] [--top-a N] [--top-n-sigma N] [--adaptive-target N] [--adaptive-decay N]
|
||||||
|
[--dry-multiplier N] [--dry-allowed-length N] [--dry-base N] [--repetition-penalty N] [--frequency-penalty N] [--presence-penalty N] [--encoder-repetition-penalty N]
|
||||||
|
[--no-repeat-ngram-size N] [--repetition-penalty-range N] [--penalty-alpha N] [--guidance-scale N] [--mirostat-mode N] [--mirostat-tau N] [--mirostat-eta N]
|
||||||
|
[--do-sample | --no-do-sample] [--dynamic-temperature | --no-dynamic-temperature] [--temperature-last | --no-temperature-last] [--sampler-priority N] [--dry-sequence-breakers N]
|
||||||
|
[--enable-thinking | --no-enable-thinking] [--reasoning-effort N] [--chat-template-file CHAT_TEMPLATE_FILE]
|
||||||
|
|
||||||
Text Generation Web UI
|
Text Generation Web UI
|
||||||
|
|
||||||
|
|
@ -262,7 +261,8 @@ options:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
|
|
||||||
Basic settings:
|
Basic settings:
|
||||||
--multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.
|
--user-data-dir USER_DATA_DIR Path to the user data directory. Default: auto-detected.
|
||||||
|
--multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.
|
||||||
--model MODEL Name of the model to load by default.
|
--model MODEL Name of the model to load by default.
|
||||||
--lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.
|
--lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.
|
||||||
--model-dir MODEL_DIR Path to directory with all the models.
|
--model-dir MODEL_DIR Path to directory with all the models.
|
||||||
|
|
@ -274,14 +274,23 @@ Basic settings:
|
||||||
--verbose Print the prompts to the terminal.
|
--verbose Print the prompts to the terminal.
|
||||||
--idle-timeout IDLE_TIMEOUT Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.
|
--idle-timeout IDLE_TIMEOUT Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.
|
||||||
|
|
||||||
|
Image model:
|
||||||
|
--image-model IMAGE_MODEL Name of the image model to select on startup (overrides saved setting).
|
||||||
|
--image-model-dir IMAGE_MODEL_DIR Path to directory with all the image models.
|
||||||
|
--image-dtype {bfloat16,float16} Data type for image model.
|
||||||
|
--image-attn-backend {flash_attention_2,sdpa} Attention backend for image model.
|
||||||
|
--image-cpu-offload Enable CPU offloading for image model.
|
||||||
|
--image-compile Compile the image model for faster inference.
|
||||||
|
--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}
|
||||||
|
Quantization method for image model.
|
||||||
|
|
||||||
Model loader:
|
Model loader:
|
||||||
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2,
|
--loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT-
|
||||||
TensorRT-LLM.
|
LLM.
|
||||||
|
|
||||||
Context and cache:
|
Context and cache:
|
||||||
--ctx-size N, --n_ctx N, --max_seq_len N Context size in tokens.
|
--ctx-size, --n_ctx, --max_seq_len N Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.
|
||||||
--cache-type N, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits
|
--cache-type, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).
|
||||||
separately, e.g. q4_q8).
|
|
||||||
|
|
||||||
Speculative decoding:
|
Speculative decoding:
|
||||||
--model-draft MODEL_DRAFT Path to the draft model for speculative decoding.
|
--model-draft MODEL_DRAFT Path to the draft model for speculative decoding.
|
||||||
|
|
@ -289,9 +298,15 @@ Speculative decoding:
|
||||||
--gpu-layers-draft GPU_LAYERS_DRAFT Number of layers to offload to the GPU for the draft model.
|
--gpu-layers-draft GPU_LAYERS_DRAFT Number of layers to offload to the GPU for the draft model.
|
||||||
--device-draft DEVICE_DRAFT Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1
|
--device-draft DEVICE_DRAFT Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1
|
||||||
--ctx-size-draft CTX_SIZE_DRAFT Size of the prompt context for the draft model. If 0, uses the same as the main model.
|
--ctx-size-draft CTX_SIZE_DRAFT Size of the prompt context for the draft model. If 0, uses the same as the main model.
|
||||||
|
--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}
|
||||||
|
Draftless speculative decoding type. Recommended: ngram-mod.
|
||||||
|
--spec-ngram-size-n SPEC_NGRAM_SIZE_N N-gram lookup size for ngram speculative decoding.
|
||||||
|
--spec-ngram-size-m SPEC_NGRAM_SIZE_M Draft n-gram size for ngram speculative decoding.
|
||||||
|
--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS Minimum n-gram hits for ngram-map speculative decoding.
|
||||||
|
|
||||||
llama.cpp:
|
llama.cpp:
|
||||||
--gpu-layers N, --n-gpu-layers N Number of layers to offload to the GPU.
|
--gpu-layers, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto.
|
||||||
|
--cpu-moe Move the experts to the CPU (for MoE models).
|
||||||
--mmproj MMPROJ Path to the mmproj file for vision models.
|
--mmproj MMPROJ Path to the mmproj file for vision models.
|
||||||
--streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.
|
--streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.
|
||||||
--tensor-split TENSOR_SPLIT Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.
|
--tensor-split TENSOR_SPLIT Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.
|
||||||
|
|
@ -299,17 +314,22 @@ llama.cpp:
|
||||||
--no-mmap Prevent mmap from being used.
|
--no-mmap Prevent mmap from being used.
|
||||||
--mlock Force the system to keep the model in RAM.
|
--mlock Force the system to keep the model in RAM.
|
||||||
--no-kv-offload Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.
|
--no-kv-offload Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.
|
||||||
--batch-size BATCH_SIZE Maximum number of prompt tokens to batch together when calling llama_eval.
|
--batch-size BATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the application level batch size.
|
||||||
|
--ubatch-size UBATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the max physical batch size for computation (device level).
|
||||||
--threads THREADS Number of threads to use.
|
--threads THREADS Number of threads to use.
|
||||||
--threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing.
|
--threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing.
|
||||||
--numa Activate NUMA task allocation for llama.cpp.
|
--numa Activate NUMA task allocation for llama.cpp.
|
||||||
|
--parallel PARALLEL Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set
|
||||||
|
ctx_size to 32768.
|
||||||
|
--fit-target FIT_TARGET Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.
|
||||||
|
Default: 1024.
|
||||||
--extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"
|
--extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"
|
||||||
|
|
||||||
Transformers/Accelerate:
|
Transformers/Accelerate:
|
||||||
--cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow.
|
--cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow.
|
||||||
--cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading.
|
--cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading.
|
||||||
--disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.
|
--disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.
|
||||||
--disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to. Defaults to "user_data/cache".
|
--disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to.
|
||||||
--load-in-8bit Load the model with 8-bit precision (using bitsandbytes).
|
--load-in-8bit Load the model with 8-bit precision (using bitsandbytes).
|
||||||
--bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
|
--bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
|
||||||
--no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.
|
--no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.
|
||||||
|
|
@ -325,30 +345,10 @@ bitsandbytes 4-bit:
|
||||||
--quant_type QUANT_TYPE quant_type for 4-bit. Valid options: nf4, fp4.
|
--quant_type QUANT_TYPE quant_type for 4-bit. Valid options: nf4, fp4.
|
||||||
|
|
||||||
ExLlamaV3:
|
ExLlamaV3:
|
||||||
|
--gpu-split GPU_SPLIT Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.
|
||||||
--enable-tp, --enable_tp Enable Tensor Parallelism (TP) to split the model across GPUs.
|
--enable-tp, --enable_tp Enable Tensor Parallelism (TP) to split the model across GPUs.
|
||||||
--tp-backend TP_BACKEND The backend for tensor parallelism. Valid options: native, nccl. Default: native.
|
--tp-backend TP_BACKEND The backend for tensor parallelism. Valid options: native, nccl. Default: native.
|
||||||
|
--cfg-cache Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.
|
||||||
ExLlamaV2:
|
|
||||||
--gpu-split GPU_SPLIT Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.
|
|
||||||
--autosplit Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.
|
|
||||||
--cfg-cache ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.
|
|
||||||
--no_flash_attn Force flash-attention to not be used.
|
|
||||||
--no_xformers Force xformers to not be used.
|
|
||||||
--no_sdpa Force Torch SDPA to not be used.
|
|
||||||
--num_experts_per_token N Number of experts to use for generation. Applies to MoE models like Mixtral.
|
|
||||||
|
|
||||||
TensorRT-LLM:
|
|
||||||
--cpp-runner Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn't support streaming yet.
|
|
||||||
|
|
||||||
DeepSpeed:
|
|
||||||
--deepspeed Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.
|
|
||||||
--nvme-offload-dir NVME_OFFLOAD_DIR DeepSpeed: Directory to use for ZeRO-3 NVME offloading.
|
|
||||||
--local_rank LOCAL_RANK DeepSpeed: Optional argument for distributed setups.
|
|
||||||
|
|
||||||
RoPE:
|
|
||||||
--alpha_value ALPHA_VALUE Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.
|
|
||||||
--rope_freq_base ROPE_FREQ_BASE If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).
|
|
||||||
--compress_pos_emb COMPRESS_POS_EMB Positional embeddings compression factor. Should be set to (context length) / (model's original context length). Equal to 1/rope_freq_scale.
|
|
||||||
|
|
||||||
Gradio:
|
Gradio:
|
||||||
--listen Make the web UI reachable from your local network.
|
--listen Make the web UI reachable from your local network.
|
||||||
|
|
@ -366,7 +366,7 @@ Gradio:
|
||||||
|
|
||||||
API:
|
API:
|
||||||
--api Enable the API extension.
|
--api Enable the API extension.
|
||||||
--public-api Create a public URL for the API using Cloudfare.
|
--public-api Create a public URL for the API using Cloudflare.
|
||||||
--public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.
|
--public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.
|
||||||
--api-port API_PORT The listening port for the API.
|
--api-port API_PORT The listening port for the API.
|
||||||
--api-key API_KEY API authentication key.
|
--api-key API_KEY API authentication key.
|
||||||
|
|
@ -374,65 +374,88 @@ API:
|
||||||
--api-enable-ipv6 Enable IPv6 for the API
|
--api-enable-ipv6 Enable IPv6 for the API
|
||||||
--api-disable-ipv4 Disable IPv4 for the API
|
--api-disable-ipv4 Disable IPv4 for the API
|
||||||
--nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode.
|
--nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode.
|
||||||
|
|
||||||
|
API generation defaults:
|
||||||
|
--temperature N Temperature
|
||||||
|
--dynatemp-low N Dynamic temperature low
|
||||||
|
--dynatemp-high N Dynamic temperature high
|
||||||
|
--dynatemp-exponent N Dynamic temperature exponent
|
||||||
|
--smoothing-factor N Smoothing factor
|
||||||
|
--smoothing-curve N Smoothing curve
|
||||||
|
--min-p N Min P
|
||||||
|
--top-p N Top P
|
||||||
|
--top-k N Top K
|
||||||
|
--typical-p N Typical P
|
||||||
|
--xtc-threshold N XTC threshold
|
||||||
|
--xtc-probability N XTC probability
|
||||||
|
--epsilon-cutoff N Epsilon cutoff
|
||||||
|
--eta-cutoff N Eta cutoff
|
||||||
|
--tfs N TFS
|
||||||
|
--top-a N Top A
|
||||||
|
--top-n-sigma N Top N Sigma
|
||||||
|
--adaptive-target N Adaptive target
|
||||||
|
--adaptive-decay N Adaptive decay
|
||||||
|
--dry-multiplier N DRY multiplier
|
||||||
|
--dry-allowed-length N DRY allowed length
|
||||||
|
--dry-base N DRY base
|
||||||
|
--repetition-penalty N Repetition penalty
|
||||||
|
--frequency-penalty N Frequency penalty
|
||||||
|
--presence-penalty N Presence penalty
|
||||||
|
--encoder-repetition-penalty N Encoder repetition penalty
|
||||||
|
--no-repeat-ngram-size N No repeat ngram size
|
||||||
|
--repetition-penalty-range N Repetition penalty range
|
||||||
|
--penalty-alpha N Penalty alpha
|
||||||
|
--guidance-scale N Guidance scale
|
||||||
|
--mirostat-mode N Mirostat mode
|
||||||
|
--mirostat-tau N Mirostat tau
|
||||||
|
--mirostat-eta N Mirostat eta
|
||||||
|
--do-sample, --no-do-sample Do sample
|
||||||
|
--dynamic-temperature, --no-dynamic-temperature Dynamic temperature
|
||||||
|
--temperature-last, --no-temperature-last Temperature last
|
||||||
|
--sampler-priority N Sampler priority
|
||||||
|
--dry-sequence-breakers N DRY sequence breakers
|
||||||
|
--enable-thinking, --no-enable-thinking Enable thinking
|
||||||
|
--reasoning-effort N Reasoning effort
|
||||||
|
--chat-template-file CHAT_TEMPLATE_FILE Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model's
|
||||||
|
built-in template.
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Downloading models
|
## Downloading models
|
||||||
|
|
||||||
Models should be placed in the folder `text-generation-webui/user_data/models`. They are usually downloaded from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf).
|
1. Download a GGUF model file from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf).
|
||||||
|
2. Place it in the `user_data/models` folder.
|
||||||
|
|
||||||
To check if a GGUF model will fit in your hardware before downloading it, you can use this tool I created:
|
That's it. The UI will detect it automatically.
|
||||||
|
|
||||||
[Accurate GGUF VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator)
|
To check what will fit your GPU, you can use the [VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator).
|
||||||
|
|
||||||
* GGUF models are a single file and should be placed directly into `user_data/models`. Example:
|
<details>
|
||||||
|
<summary>Other model types (Transformers, EXL3)</summary>
|
||||||
|
|
||||||
|
Models that consist of multiple files (like 16-bit Transformers models and EXL3 models) should be placed in a subfolder inside `user_data/models`:
|
||||||
|
|
||||||
```
|
```
|
||||||
text-generation-webui
|
text-generation-webui
|
||||||
└── user_data
|
└── user_data
|
||||||
└── models
|
└── models
|
||||||
└── llama-2-13b-chat.Q4_K_M.gguf
|
└── Qwen_Qwen3-8B
|
||||||
```
|
|
||||||
|
|
||||||
* The remaining model types (like 16-bit Transformers models and EXL3 models) are made of several files and must be placed in a subfolder. Example:
|
|
||||||
|
|
||||||
```
|
|
||||||
text-generation-webui
|
|
||||||
└── user_data
|
|
||||||
└── models
|
|
||||||
└── lmsys_vicuna-33b-v1.3
|
|
||||||
├── config.json
|
├── config.json
|
||||||
├── generation_config.json
|
├── generation_config.json
|
||||||
├── pytorch_model-00001-of-00007.bin
|
├── model-00001-of-00004.safetensors
|
||||||
├── pytorch_model-00002-of-00007.bin
|
├── ...
|
||||||
├── pytorch_model-00003-of-00007.bin
|
|
||||||
├── pytorch_model-00004-of-00007.bin
|
|
||||||
├── pytorch_model-00005-of-00007.bin
|
|
||||||
├── pytorch_model-00006-of-00007.bin
|
|
||||||
├── pytorch_model-00007-of-00007.bin
|
|
||||||
├── pytorch_model.bin.index.json
|
|
||||||
├── special_tokens_map.json
|
|
||||||
├── tokenizer_config.json
|
├── tokenizer_config.json
|
||||||
└── tokenizer.model
|
└── tokenizer.json
|
||||||
```
|
```
|
||||||
|
|
||||||
In both cases, you can use the "Model" tab of the UI to download the model from Hugging Face automatically. It is also possible to download it via the command-line with:
|
These formats require the one-click installer (not the portable build).
|
||||||
|
</details>
|
||||||
```
|
|
||||||
python download-model.py organization/model
|
|
||||||
```
|
|
||||||
|
|
||||||
Run `python download-model.py --help` to see all the options.
|
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
https://github.com/oobabooga/text-generation-webui/wiki
|
https://github.com/oobabooga/text-generation-webui/wiki
|
||||||
|
|
||||||
## Google Colab notebook
|
|
||||||
|
|
||||||
https://colab.research.google.com/github/oobabooga/text-generation-webui/blob/main/Colab-TextGen-GPU.ipynb
|
|
||||||
|
|
||||||
## Community
|
## Community
|
||||||
|
|
||||||
https://www.reddit.com/r/Oobabooga/
|
https://www.reddit.com/r/Oobabooga/
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ set INSTALL_ENV_DIR=%cd%\installer_files\env
|
||||||
set PYTHONNOUSERSITE=1
|
set PYTHONNOUSERSITE=1
|
||||||
set PYTHONPATH=
|
set PYTHONPATH=
|
||||||
set PYTHONHOME=
|
set PYTHONHOME=
|
||||||
|
set PYTHONUTF8=1
|
||||||
set "CUDA_PATH=%INSTALL_ENV_DIR%"
|
set "CUDA_PATH=%INSTALL_ENV_DIR%"
|
||||||
set "CUDA_HOME=%CUDA_PATH%"
|
set "CUDA_HOME=%CUDA_PATH%"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
display: grid;
|
display: grid;
|
||||||
align-items: start;
|
align-items: start;
|
||||||
grid-template-columns: 60px minmax(0, 1fr);
|
grid-template-columns: 60px minmax(0, 1fr);
|
||||||
|
width: min(100%, calc(724px + 60px));
|
||||||
padding-bottom: 22px;
|
padding-bottom: 22px;
|
||||||
padding-top: 6px;
|
padding-top: 6px;
|
||||||
font-size: 18px;
|
font-size: 18px;
|
||||||
|
|
@ -91,9 +92,6 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p {
|
||||||
margin-bottom: 0 !important;
|
|
||||||
font-size: 16px !important;
|
|
||||||
line-height: 1.5 !important;
|
|
||||||
color: #e0e0e0 !important; /* Light color for text */
|
color: #e0e0e0 !important; /* Light color for text */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -122,7 +120,7 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p {
|
||||||
font-size: 14px !important; /* Smaller text for mobile */
|
font-size: 14px !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.username {
|
.username {
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
display: grid;
|
display: grid;
|
||||||
align-items: start;
|
align-items: start;
|
||||||
grid-template-columns: 60px minmax(0, 1fr);
|
grid-template-columns: 60px minmax(0, 1fr);
|
||||||
|
width: min(100%, calc(724px + 60px + 90px));
|
||||||
padding-bottom: 21px;
|
padding-bottom: 21px;
|
||||||
padding-top: 7px;
|
padding-top: 7px;
|
||||||
font-size: 18px;
|
font-size: 18px;
|
||||||
|
|
@ -86,10 +87,8 @@
|
||||||
border-radius: 20px;
|
border-radius: 20px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p, .message-body li {
|
||||||
margin-bottom: 0 !important;
|
|
||||||
font-size: 18px !important;
|
font-size: 18px !important;
|
||||||
line-height: 1.428571429 !important;
|
|
||||||
color: rgb(243 244 246) !important;
|
color: rgb(243 244 246) !important;
|
||||||
text-shadow: 2px 2px 2px rgb(0 0 0);
|
text-shadow: 2px 2px 2px rgb(0 0 0);
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
|
|
@ -127,7 +126,7 @@
|
||||||
padding-left: 0;
|
padding-left: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p, .message-body li {
|
||||||
font-size: 16px !important;
|
font-size: 16px !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,4 +19,5 @@
|
||||||
padding-bottom: 1.5em;
|
padding-bottom: 1.5em;
|
||||||
padding-top: 0.5em;
|
padding-top: 0.5em;
|
||||||
grid-template-columns: 70px minmax(0, 1fr);
|
grid-template-columns: 70px minmax(0, 1fr);
|
||||||
|
width: min(100%, calc(724px + 70px));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
display: grid;
|
display: grid;
|
||||||
align-items: start;
|
align-items: start;
|
||||||
grid-template-columns: 60px minmax(0, 1fr);
|
grid-template-columns: 60px minmax(0, 1fr);
|
||||||
|
width: min(100%, calc(724px + 60px));
|
||||||
padding-bottom: 1.5em;
|
padding-bottom: 1.5em;
|
||||||
padding-top: 0.5em;
|
padding-top: 0.5em;
|
||||||
font-size: 15px;
|
font-size: 15px;
|
||||||
|
|
@ -46,16 +47,10 @@
|
||||||
border-radius: 20px;
|
border-radius: 20px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p, .message-body li {
|
||||||
font-size: 15px !important;
|
|
||||||
line-height: 22.5px !important;
|
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p, .chat .message-body ul, .chat .message-body ol {
|
|
||||||
margin-bottom: 10px !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
.dark .message-body p em {
|
.dark .message-body p em {
|
||||||
color: rgb(138 138 138) !important;
|
color: rgb(138 138 138) !important;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
.message {
|
.message {
|
||||||
|
width: min(100%, calc(724px + 60px));
|
||||||
padding-bottom: 22px;
|
padding-bottom: 22px;
|
||||||
padding-top: 3px;
|
padding-top: 3px;
|
||||||
font-size: 15px;
|
font-size: 15px;
|
||||||
|
|
@ -60,8 +61,10 @@
|
||||||
text-align: right;
|
text-align: right;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .circle-bot + .text div, .dark .circle-bot + .text * {
|
.dark .circle-bot + .text div, .dark .circle-bot + .text *,
|
||||||
color: #000;
|
.dark .chat .message .circle-bot + .text .message-body :is(h1, h2, h3, h4, h5, h6),
|
||||||
|
.dark .chat .message .circle-bot + .text .message-body a {
|
||||||
|
color: #000 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.text {
|
.text {
|
||||||
|
|
@ -76,19 +79,14 @@
|
||||||
font-weight: bold;
|
font-weight: bold;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body {
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-body img {
|
.message-body img {
|
||||||
max-width: 300px;
|
max-width: 300px;
|
||||||
max-height: 300px;
|
max-height: 300px;
|
||||||
border-radius: 20px;
|
border-radius: 20px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p, .message-body li {
|
||||||
margin-bottom: 0 !important;
|
|
||||||
font-size: 15px !important;
|
font-size: 15px !important;
|
||||||
line-height: 1.428571429 !important;
|
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
.message {
|
.message {
|
||||||
display: block;
|
display: block;
|
||||||
|
width: min(100%, 724px);
|
||||||
padding-top: 0;
|
padding-top: 0;
|
||||||
padding-bottom: 21px;
|
padding-bottom: 21px;
|
||||||
font-size: 15px;
|
font-size: 15px;
|
||||||
|
|
@ -77,14 +78,8 @@
|
||||||
border-radius: 12px;
|
border-radius: 12px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p, .message-body li {
|
||||||
font-size: 15px !important;
|
font-size: 15px !important;
|
||||||
line-height: 1.4 !important;
|
|
||||||
font-weight: 400;
|
|
||||||
}
|
|
||||||
|
|
||||||
.message-body p:first-child {
|
|
||||||
margin-top: 0 !important;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .message-body p em {
|
.dark .message-body p em {
|
||||||
|
|
@ -100,6 +95,3 @@
|
||||||
margin-top: 8px;
|
margin-top: 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body p, .chat .message-body ul, .chat .message-body ol {
|
|
||||||
margin-bottom: 10px !important;
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,14 @@
|
||||||
color: #d1d5db !important;
|
color: #d1d5db !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat .message-body :is(th, td) {
|
.chat .message-body :is(th, td),
|
||||||
|
.prose hr {
|
||||||
border-color: #40404096 !important;
|
border-color: #40404096 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .chat .message-body :is(th, td) {
|
.dark .chat .message-body :is(th, td),
|
||||||
border-color: #ffffff75 !important;
|
.dark .prose hr {
|
||||||
|
border-color: rgb(255 255 255 / 30%) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat .message-body :is(p, ul, ol) {
|
.chat .message-body :is(p, ul, ol) {
|
||||||
|
|
@ -76,7 +78,7 @@
|
||||||
|
|
||||||
.chat .user-message .text,
|
.chat .user-message .text,
|
||||||
.chat .assistant-message .text {
|
.chat .assistant-message .text {
|
||||||
max-width: 700px;
|
max-width: 724px;
|
||||||
margin-left: auto;
|
margin-left: auto;
|
||||||
margin-right: auto;
|
margin-right: auto;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
186
css/main.css
186
css/main.css
|
|
@ -400,7 +400,6 @@ audio {
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat .message {
|
.chat .message {
|
||||||
width: min(100%, 48rem);
|
|
||||||
margin-left: auto;
|
margin-left: auto;
|
||||||
margin-right: auto;
|
margin-right: auto;
|
||||||
text-align: start;
|
text-align: start;
|
||||||
|
|
@ -431,10 +430,19 @@ audio {
|
||||||
font-size: 16px;
|
font-size: 16px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .message-body :is(h1, h2, h3, h4, h5, h6) {
|
.dark .message-body h1,
|
||||||
|
.dark .message-body h2,
|
||||||
|
.dark .message-body h3,
|
||||||
|
.dark .message-body h4,
|
||||||
|
.dark .message-body h5,
|
||||||
|
.dark .message-body h6 {
|
||||||
color: white !important;
|
color: white !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.dark .message-body blockquote {
|
||||||
|
border-left-color: rgb(255 255 255 / 30%);
|
||||||
|
}
|
||||||
|
|
||||||
.message-body h1 {
|
.message-body h1 {
|
||||||
font-weight: 800;
|
font-weight: 800;
|
||||||
font-size: 2.25em;
|
font-size: 2.25em;
|
||||||
|
|
@ -715,7 +723,7 @@ audio {
|
||||||
.hover-menu {
|
.hover-menu {
|
||||||
display: none;
|
display: none;
|
||||||
position: absolute;
|
position: absolute;
|
||||||
bottom: 80%;
|
bottom: 100%;
|
||||||
left: 0;
|
left: 0;
|
||||||
box-shadow: 0 0 5px rgb(0 0 0 / 25%);
|
box-shadow: 0 0 5px rgb(0 0 0 / 25%);
|
||||||
z-index: 10000;
|
z-index: 10000;
|
||||||
|
|
@ -831,9 +839,20 @@ audio {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body ol, .message-body ul {
|
.message-body p, .message-body li {
|
||||||
|
line-height: 1.75 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body p, .message-body ul, .message-body ol {
|
||||||
|
margin: 1.25em 0 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body :is(p, ul, ol):first-child {
|
||||||
margin-top: 0 !important;
|
margin-top: 0 !important;
|
||||||
margin-bottom: 1.25em !important;
|
}
|
||||||
|
|
||||||
|
.message-body :is(p, ul, ol):last-child {
|
||||||
|
margin-bottom: 0 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ----------------------------------------------
|
/* ----------------------------------------------
|
||||||
|
|
@ -1003,6 +1022,49 @@ audio {
|
||||||
padding-right: 0.5rem;
|
padding-right: 0.5rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#new-chat-wrapper {
|
||||||
|
display: contents;
|
||||||
|
}
|
||||||
|
|
||||||
|
.new-chat-arrow {
|
||||||
|
cursor: pointer;
|
||||||
|
position: relative;
|
||||||
|
padding: 0;
|
||||||
|
margin-right: -15px;
|
||||||
|
height: 39.594px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.new-chat-menu {
|
||||||
|
display: none;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
padding-top: 1.2em;
|
||||||
|
z-index: var(--layer-top);
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.new-chat-arrow:hover .new-chat-menu {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
.new-chat-menu-item {
|
||||||
|
cursor: pointer;
|
||||||
|
padding: var(--size-2);
|
||||||
|
background: var(--background-fill-primary);
|
||||||
|
box-shadow: var(--shadow-drop-lg);
|
||||||
|
border-radius: var(--container-radius);
|
||||||
|
color: var(--body-text-color);
|
||||||
|
font-size: var(--text-md);
|
||||||
|
font-weight: var(--button-large-text-weight);
|
||||||
|
}
|
||||||
|
|
||||||
|
.new-chat-menu-item:hover {
|
||||||
|
background: var(--background-fill-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
#past-chats-row,
|
#past-chats-row,
|
||||||
#chat-controls {
|
#chat-controls {
|
||||||
width: 260px;
|
width: 260px;
|
||||||
|
|
@ -1373,7 +1435,6 @@ audio {
|
||||||
overflow-wrap: break-word;
|
overflow-wrap: break-word;
|
||||||
max-height: 250px;
|
max-height: 250px;
|
||||||
overflow-y: scroll;
|
overflow-y: scroll;
|
||||||
contain: layout;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat .message-body .thinking-content p,
|
.chat .message-body .thinking-content p,
|
||||||
|
|
@ -1645,7 +1706,7 @@ button:focus {
|
||||||
}
|
}
|
||||||
|
|
||||||
#user-description textarea {
|
#user-description textarea {
|
||||||
height: calc(100vh - 231px) !important;
|
height: calc(100vh - 334px) !important;
|
||||||
min-height: 90px !important;
|
min-height: 90px !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1662,7 +1723,7 @@ button:focus {
|
||||||
.chat-parent {
|
.chat-parent {
|
||||||
/* Optimize for scrolling performance */
|
/* Optimize for scrolling performance */
|
||||||
will-change: scroll-position;
|
will-change: scroll-position;
|
||||||
contain: layout style paint;
|
contain: style paint;
|
||||||
|
|
||||||
/* Ensure GPU acceleration */
|
/* Ensure GPU acceleration */
|
||||||
transform: translateZ(0);
|
transform: translateZ(0);
|
||||||
|
|
@ -1797,3 +1858,112 @@ button#swap-height-width {
|
||||||
top: 0;
|
top: 0;
|
||||||
left: calc(100% - 174px);
|
left: calc(100% - 174px);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table {
|
||||||
|
border-collapse: collapse;
|
||||||
|
}
|
||||||
|
|
||||||
|
.table-wrapper {
|
||||||
|
overflow-x: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body :is(td, th) {
|
||||||
|
word-break: normal;
|
||||||
|
overflow-wrap: normal;
|
||||||
|
}
|
||||||
|
|
||||||
|
table, tr, td, th, thead {
|
||||||
|
border: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
td + td,
|
||||||
|
th + th { border-left: 1px solid; }
|
||||||
|
|
||||||
|
tr + tr td,
|
||||||
|
tr + tr th { border-top: 1px solid; }
|
||||||
|
|
||||||
|
thead + tbody tr:first-child td,
|
||||||
|
thead + tbody tr:first-child th { border-top: 1px solid; }
|
||||||
|
|
||||||
|
/* ------------------------------------------------
|
||||||
|
Tools CheckboxGroup - vertical DragDrop-like style
|
||||||
|
------------------------------------------------ */
|
||||||
|
|
||||||
|
/* "Refresh list" link in the Tools label */
|
||||||
|
.tools-refresh-link {
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Checkbox list container */
|
||||||
|
#tools-group {
|
||||||
|
padding: 0 !important;
|
||||||
|
border-width: 0 !important;
|
||||||
|
background: transparent !important;
|
||||||
|
min-height: 0 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
#tools-group .wrap {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
flex-wrap: nowrap;
|
||||||
|
gap: 4px;
|
||||||
|
padding: 0;
|
||||||
|
margin-top: var(--spacing-lg);
|
||||||
|
max-height: 350px;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Pretty scrollbar for the tools list */
|
||||||
|
#tools-group .wrap::-webkit-scrollbar {
|
||||||
|
width: 8px;
|
||||||
|
height: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#tools-group .wrap::-webkit-scrollbar-track {
|
||||||
|
background: transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
#tools-group .wrap::-webkit-scrollbar-thumb,
|
||||||
|
#tools-group .wrap::-webkit-scrollbar-thumb:hover {
|
||||||
|
background: var(--neutral-300);
|
||||||
|
border-radius: 30px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark #tools-group .wrap::-webkit-scrollbar-thumb,
|
||||||
|
.dark #tools-group .wrap::-webkit-scrollbar-thumb:hover {
|
||||||
|
background: rgb(255 255 255 / 6.25%);
|
||||||
|
border-radius: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#tools-group .wrap::-webkit-scrollbar-corner {
|
||||||
|
background: transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Each checkbox item */
|
||||||
|
#tools-group label {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
padding: 5px 8px;
|
||||||
|
border-radius: var(--radius-sm, 4px);
|
||||||
|
background: var(--block-background-fill);
|
||||||
|
border: 1px solid var(--border-color-primary);
|
||||||
|
color: var(--body-text-color);
|
||||||
|
font-size: var(--input-text-size);
|
||||||
|
font-weight: var(--input-text-weight);
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
transition: border-color 0.15s ease, background 0.15s ease;
|
||||||
|
box-shadow: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
#tools-group label:hover {
|
||||||
|
border-color: var(--input-border-color-focus);
|
||||||
|
}
|
||||||
|
|
||||||
|
#tools-group label span {
|
||||||
|
flex: 1;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,3 @@
|
||||||
.env
|
.env
|
||||||
Dockerfile
|
Dockerfile
|
||||||
/characters
|
/user_data
|
||||||
/loras
|
|
||||||
/models
|
|
||||||
/presets
|
|
||||||
/prompts
|
|
||||||
/training
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
# by default the Dockerfile specifies these versions: 3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX
|
# specify which cuda arch version your card supports (NVIDIA only)
|
||||||
# however for me to work i had to specify the exact version for my card ( 2060 ) it was 7.5
|
# https://developer.nvidia.com/cuda-gpus
|
||||||
# https://developer.nvidia.com/cuda-gpus you can find the version for your card here
|
# or run: nvidia-smi --query-gpu=name,compute_cap --format=csv
|
||||||
# Or for a programatic approach run `nvidia-smi --query-gpu=name,compute_cap --format=csv`
|
# default in docker-compose.yml covers RTX 3090 (8.6) and RTX 4090 (8.9)
|
||||||
TORCH_CUDA_ARCH_LIST=7.5
|
TORCH_CUDA_ARCH_LIST=8.6;8.9+PTX
|
||||||
# the port the webui binds to on the host
|
# the port the webui binds to on the host
|
||||||
HOST_PORT=7860
|
HOST_PORT=7860
|
||||||
# the port the webui binds to inside the container
|
# the port the webui binds to inside the container
|
||||||
|
|
@ -19,6 +19,3 @@ APP_RUNTIME_GID=6972
|
||||||
# override default app build permissions (handy for deploying to cloud)
|
# override default app build permissions (handy for deploying to cloud)
|
||||||
#APP_GID=6972
|
#APP_GID=6972
|
||||||
#APP_UID=6972
|
#APP_UID=6972
|
||||||
# Set cache env
|
|
||||||
TRANSFORMERS_CACHE=/home/app/text-generation-webui/cache/
|
|
||||||
HF_HOME=/home/app/text-generation-webui/cache/
|
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,24 @@
|
||||||
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
|
FROM nvidia/cuda:13.0.1-cudnn-runtime-ubuntu24.04
|
||||||
|
|
||||||
# Install Git
|
# Install Python 3.12, Git, and OpenMPI
|
||||||
RUN apt update && apt install -y git
|
RUN apt update && apt install -y python3.12 python3-pip git build-essential openmpi-bin libopenmpi-dev
|
||||||
|
|
||||||
# System-wide TensorRT-LLM requirements
|
|
||||||
RUN apt install -y openmpi-bin libopenmpi-dev
|
|
||||||
|
|
||||||
# Set the working directory
|
# Set the working directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Install text-generation-webui
|
|
||||||
RUN git clone https://github.com/oobabooga/text-generation-webui
|
|
||||||
WORKDIR /app/text-generation-webui
|
|
||||||
RUN pip install -r requirements.txt
|
|
||||||
|
|
||||||
# This is needed to avoid an error about "Failed to build mpi4py" in the next command
|
# This is needed to avoid an error about "Failed to build mpi4py" in the next command
|
||||||
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
|
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
# Install text-generation-webui
|
||||||
|
RUN git clone https://github.com/oobabooga/text-generation-webui
|
||||||
|
WORKDIR /app/text-generation-webui
|
||||||
|
RUN pip install --break-system-packages -r requirements/full/requirements.txt
|
||||||
|
|
||||||
# Install TensorRT-LLM
|
# Install TensorRT-LLM
|
||||||
RUN pip3 install tensorrt_llm==0.10.0 -U --pre --extra-index-url https://pypi.nvidia.com
|
RUN pip3 install --break-system-packages tensorrt_llm==1.1.0 --extra-index-url https://pypi.nvidia.com
|
||||||
|
|
||||||
# Expose the necessary port for the Python server
|
# Expose the necessary port for the Python server
|
||||||
EXPOSE 7860 5000
|
EXPOSE 7860 5000
|
||||||
|
|
||||||
# Run the Python server.py script with the specified command
|
# Run the Python server.py script with the specified command
|
||||||
CMD ["python", "server.py", "--api", "--listen"]
|
CMD ["python3", "server.py", "--api", "--listen"]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
# BUILDER
|
# BUILDER
|
||||||
FROM ubuntu:22.04
|
FROM ubuntu:22.04
|
||||||
WORKDIR /builder
|
WORKDIR /builder
|
||||||
ARG TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX}"
|
|
||||||
ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}"
|
ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}"
|
||||||
ARG APP_UID="${APP_UID:-6972}"
|
ARG APP_UID="${APP_UID:-6972}"
|
||||||
ARG APP_GID="${APP_GID:-6972}"
|
ARG APP_GID="${APP_GID:-6972}"
|
||||||
|
|
@ -14,8 +13,7 @@ WORKDIR /home/app/
|
||||||
RUN git clone https://github.com/oobabooga/text-generation-webui.git
|
RUN git clone https://github.com/oobabooga/text-generation-webui.git
|
||||||
WORKDIR /home/app/text-generation-webui
|
WORKDIR /home/app/text-generation-webui
|
||||||
RUN GPU_CHOICE=B LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
|
RUN GPU_CHOICE=B LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
|
||||||
COPY /user_data/CMD_FLAGS.txt /home/app/text-generation-webui/user_data
|
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000}
|
||||||
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005}
|
|
||||||
WORKDIR /home/app/text-generation-webui
|
WORKDIR /home/app/text-generation-webui
|
||||||
# set umask to ensure group read / write at runtime
|
# set umask to ensure group read / write at runtime
|
||||||
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh
|
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,6 @@ services:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
args:
|
args:
|
||||||
# Requirements file to use:
|
|
||||||
# | GPU | requirements file to use |
|
|
||||||
# |--------|---------|
|
|
||||||
# | NVIDIA | `requirements.txt` |
|
|
||||||
# | AMD | `requirements_amd.txt` |
|
|
||||||
# | CPU only | `requirements_cpu_only.txt` |
|
|
||||||
# | Apple Intel | `requirements_apple_intel.txt` |
|
|
||||||
# | Apple Silicon | `requirements_apple_silicon.txt` |
|
|
||||||
# Default: requirements.txt`
|
|
||||||
# BUILD_REQUIREMENTS: requirements.txt
|
|
||||||
|
|
||||||
# Extension requirements to build:
|
|
||||||
# BUILD_EXTENSIONS:
|
|
||||||
|
|
||||||
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
|
||||||
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
|
|
||||||
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
|
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
|
||||||
APP_GID: ${APP_GID:-6972}
|
APP_GID: ${APP_GID:-6972}
|
||||||
APP_UID: ${APP_UID:-6972}
|
APP_UID: ${APP_UID:-6972}
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,9 @@
|
||||||
# BUILDER
|
# BUILDER
|
||||||
FROM ubuntu:22.04
|
FROM ubuntu:22.04
|
||||||
WORKDIR /builder
|
WORKDIR /builder
|
||||||
ARG TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX}"
|
|
||||||
ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}"
|
ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}"
|
||||||
ARG APP_UID="${APP_UID:-6972}"
|
ARG APP_UID="${APP_UID:-6972}"
|
||||||
ARG APP_GID="${APP_GID:-6972}"
|
ARG APP_GID="${APP_GID:-6972}"
|
||||||
ARG GPU_CHOICE=A
|
|
||||||
ARG USE_CUDA118=FALSE
|
|
||||||
ARG LAUNCH_AFTER_INSTALL=FALSE
|
|
||||||
ARG INSTALL_EXTENSIONS=TRUE
|
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw \
|
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw \
|
||||||
apt update && \
|
apt update && \
|
||||||
|
|
@ -18,8 +13,7 @@ WORKDIR /home/app/
|
||||||
RUN git clone https://github.com/oobabooga/text-generation-webui.git
|
RUN git clone https://github.com/oobabooga/text-generation-webui.git
|
||||||
WORKDIR /home/app/text-generation-webui
|
WORKDIR /home/app/text-generation-webui
|
||||||
RUN GPU_CHOICE=N LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
|
RUN GPU_CHOICE=N LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
|
||||||
COPY CMD_FLAGS.txt /home/app/text-generation-webui/
|
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000}
|
||||||
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005}
|
|
||||||
# set umask to ensure group read / write at runtime
|
# set umask to ensure group read / write at runtime
|
||||||
WORKDIR /home/app/text-generation-webui
|
WORKDIR /home/app/text-generation-webui
|
||||||
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh
|
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,6 @@ services:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
args:
|
args:
|
||||||
# Requirements file to use:
|
|
||||||
# | GPU | requirements file to use |
|
|
||||||
# |--------|---------|
|
|
||||||
# | NVIDIA | `requirements.txt` |
|
|
||||||
# | AMD | `requirements_amd.txt` |
|
|
||||||
# | CPU only | `requirements_cpu_only.txt` |
|
|
||||||
# | Apple Intel | `requirements_apple_intel.txt` |
|
|
||||||
# | Apple Silicon | `requirements_apple_silicon.txt` |
|
|
||||||
# Default: requirements.txt`
|
|
||||||
# BUILD_REQUIREMENTS: requirements.txt
|
|
||||||
|
|
||||||
# Extension requirements to build:
|
|
||||||
# BUILD_EXTENSIONS:
|
|
||||||
|
|
||||||
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
|
||||||
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
|
|
||||||
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
|
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
|
||||||
APP_GID: ${APP_GID:-6972}
|
APP_GID: ${APP_GID:-6972}
|
||||||
APP_UID: ${APP_UID:-6972}
|
APP_UID: ${APP_UID:-6972}
|
||||||
|
|
@ -31,14 +15,4 @@ services:
|
||||||
stdin_open: true
|
stdin_open: true
|
||||||
tty: true
|
tty: true
|
||||||
volumes:
|
volumes:
|
||||||
- ./cache:/home/app/text-generation-webui/cache
|
- ./user_data:/home/app/text-generation-webui/user_data
|
||||||
- ./characters:/home/app/text-generation-webui/characters
|
|
||||||
- ./extensions:/home/app/text-generation-webui/extensions
|
|
||||||
- ./loras:/home/app/text-generation-webui/loras
|
|
||||||
- ./logs:/home/app/text-generation-webui/logs
|
|
||||||
- ./models:/home/app/text-generation-webui/models
|
|
||||||
- ./presets:/home/app/text-generation-webui/presets
|
|
||||||
- ./prompts:/home/app/text-generation-webui/prompts
|
|
||||||
- ./softprompts:/home/app/text-generation-webui/softprompts
|
|
||||||
- ./training:/home/app/text-generation-webui/training
|
|
||||||
- ./cloudflared:/etc/cloudflared
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
# BUILDER
|
# BUILDER
|
||||||
FROM ubuntu:22.04
|
FROM ubuntu:22.04
|
||||||
WORKDIR /builder
|
WORKDIR /builder
|
||||||
ARG TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX}"
|
|
||||||
ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}"
|
ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}"
|
||||||
ARG APP_UID="${APP_UID:-6972}"
|
ARG APP_UID="${APP_UID:-6972}"
|
||||||
ARG APP_GID="${APP_GID:-6972}"
|
ARG APP_GID="${APP_GID:-6972}"
|
||||||
|
|
@ -14,8 +13,7 @@ WORKDIR /home/app/
|
||||||
RUN git clone https://github.com/oobabooga/text-generation-webui.git
|
RUN git clone https://github.com/oobabooga/text-generation-webui.git
|
||||||
WORKDIR /home/app/text-generation-webui
|
WORKDIR /home/app/text-generation-webui
|
||||||
RUN GPU_CHOICE=D LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
|
RUN GPU_CHOICE=D LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
|
||||||
COPY /user_data/CMD_FLAGS.txt /home/app/text-generation-webui/user_data
|
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000}
|
||||||
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005}
|
|
||||||
# set umask to ensure group read / write at runtime
|
# set umask to ensure group read / write at runtime
|
||||||
WORKDIR /home/app/text-generation-webui
|
WORKDIR /home/app/text-generation-webui
|
||||||
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh
|
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,6 @@ services:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
args:
|
args:
|
||||||
# Requirements file to use:
|
|
||||||
# | GPU | requirements file to use |
|
|
||||||
# |--------|---------|
|
|
||||||
# | NVIDIA | `requirements.txt` |
|
|
||||||
# | AMD | `requirements_amd.txt` |
|
|
||||||
# | CPU only | `requirements_cpu_only.txt` |
|
|
||||||
# | Apple Intel | `requirements_apple_intel.txt` |
|
|
||||||
# | Apple Silicon | `requirements_apple_silicon.txt` |
|
|
||||||
# Default: requirements.txt`
|
|
||||||
# BUILD_REQUIREMENTS: requirements.txt
|
|
||||||
|
|
||||||
# Extension requirements to build:
|
|
||||||
# BUILD_EXTENSIONS:
|
|
||||||
|
|
||||||
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
|
||||||
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
|
|
||||||
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
|
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
|
||||||
APP_GID: ${APP_GID:-6972}
|
APP_GID: ${APP_GID:-6972}
|
||||||
APP_UID: ${APP_UID:-6972}
|
APP_UID: ${APP_UID:-6972}
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,7 @@ WORKDIR /home/app/
|
||||||
RUN git clone https://github.com/oobabooga/text-generation-webui.git
|
RUN git clone https://github.com/oobabooga/text-generation-webui.git
|
||||||
WORKDIR /home/app/text-generation-webui
|
WORKDIR /home/app/text-generation-webui
|
||||||
RUN GPU_CHOICE=A LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
|
RUN GPU_CHOICE=A LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose
|
||||||
COPY /user_data/CMD_FLAGS.txt /home/app/text-generation-webui/user_data
|
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000}
|
||||||
EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} ${CONTAINER_API_STREAM_PORT:-5005}
|
|
||||||
WORKDIR /home/app/text-generation-webui
|
WORKDIR /home/app/text-generation-webui
|
||||||
# set umask to ensure group read / write at runtime
|
# set umask to ensure group read / write at runtime
|
||||||
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen
|
CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,8 @@ services:
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
args:
|
args:
|
||||||
# Requirements file to use:
|
|
||||||
# | GPU | requirements file to use |
|
|
||||||
# |--------|---------|
|
|
||||||
# | NVIDIA | `requirements.txt` |
|
|
||||||
# | AMD | `requirements_amd.txt` |
|
|
||||||
# | CPU only | `requirements_cpu_only.txt` |
|
|
||||||
# | Apple Intel | `requirements_apple_intel.txt` |
|
|
||||||
# | Apple Silicon | `requirements_apple_silicon.txt` |
|
|
||||||
# Default: requirements.txt`
|
|
||||||
# BUILD_REQUIREMENTS: requirements.txt
|
|
||||||
|
|
||||||
# Extension requirements to build:
|
|
||||||
# BUILD_EXTENSIONS:
|
|
||||||
|
|
||||||
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
||||||
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
|
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-8.6;8.9+PTX}
|
||||||
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
|
BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-}
|
||||||
APP_GID: ${APP_GID:-6972}
|
APP_GID: ${APP_GID:-6972}
|
||||||
APP_UID: ${APP_UID:-6972}
|
APP_UID: ${APP_UID:-6972}
|
||||||
|
|
|
||||||
|
|
@ -2,31 +2,44 @@ Used to have multi-turn conversations with the model.
|
||||||
|
|
||||||
## Input area
|
## Input area
|
||||||
|
|
||||||
The following buttons can be found. Note that the hover menu can be replaced with always-visible buttons with the `--chat-buttons` flag.
|
The main action buttons are:
|
||||||
|
|
||||||
* **Generate**: sends your message and makes the model start a reply.
|
* **Send**: sends your message and makes the model start a reply.
|
||||||
* **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model).
|
* **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model).
|
||||||
|
|
||||||
|
The hover menu (☰) that appears over the chat area contains:
|
||||||
|
|
||||||
|
* **Regenerate**: similar to Send, but your last message is used as input instead of the text in the input field. Note that if the temperature/top_p/top_k parameters are low in the "Parameters" tab of the UI, the new reply may end up identical to the previous one.
|
||||||
* **Continue**: makes the model attempt to continue the existing reply. In some cases, the model may simply end the existing turn immediately without generating anything new, but in other cases, it may generate a longer reply.
|
* **Continue**: makes the model attempt to continue the existing reply. In some cases, the model may simply end the existing turn immediately without generating anything new, but in other cases, it may generate a longer reply.
|
||||||
* **Regenerate**: similar to Generate, but your last message is used as input instead of the text in the input field. Note that if the temperature/top_p/top_k parameters are low in the "Parameters" tab of the UI, the new reply may end up identical to the previous one.
|
|
||||||
* **Remove last reply**: removes the last input/output pair from the history and sends your last message back into the input field.
|
* **Remove last reply**: removes the last input/output pair from the history and sends your last message back into the input field.
|
||||||
* **Replace last reply**: replaces the last reply with whatever you typed into the input field. Useful in conjunction with "Copy last reply" if you want to edit the bot response.
|
|
||||||
* **Copy last reply**: sends the contents of the bot's last reply to the input field.
|
|
||||||
* **Impersonate**: makes the model generate a new message on your behalf in the input field, taking into consideration the existing chat history.
|
* **Impersonate**: makes the model generate a new message on your behalf in the input field, taking into consideration the existing chat history.
|
||||||
* **Send dummy message**: adds a new message to the chat history without causing the model to generate a reply.
|
* **Send dummy message**: adds a new message to the chat history without causing the model to generate a reply.
|
||||||
* **Send dummy reply**: adds a new reply to the chat history as if the model had generated this reply. Useful in conjunction with "Send dummy message".
|
* **Send dummy reply**: adds a new reply to the chat history as if the model had generated this reply. Useful in conjunction with "Send dummy message".
|
||||||
* **Start new chat**: starts a new conversation while keeping the old one saved. If you are talking to a character that has a "Greeting" message defined, this message will be automatically added to the new history.
|
* **Send to Notebook**: sends the entire chat prompt up to now to the Notebook tab.
|
||||||
* **Send to default**: sends the entire chat prompt up to now to the "Default" tab.
|
* **Show controls**: checkbox that toggles the visibility of the sidebar controls (Start reply with, Mode, Chat style, etc.). Shortcut: Ctrl+S.
|
||||||
* **Send to notebook**: sends the entire chat prompt up to now to the "Notebook" tab.
|
|
||||||
|
|
||||||
The **Show controls** checkbox causes the input fields below the input textbox to disappear. It is useful for making the page fit entirely into view and not scroll.
|
|
||||||
|
|
||||||
## Past chats
|
## Past chats
|
||||||
|
|
||||||
Allows you to switch between the current and previous conversations with the current character, or between the current and previous instruct conversations (if in "instruct" mode). The **Rename** menu can be used to give a unique name to the selected conversation, and the 🗑️ button allows you to delete it.
|
Allows you to switch between the current and previous conversations with the current character, or between the current and previous instruct conversations (if in "instruct" mode). The available buttons are:
|
||||||
|
|
||||||
## Start reply with
|
* **Branch**: creates a branch of the current conversation at a specific message.
|
||||||
|
* **Rename**: allows you to give a unique name to the selected conversation.
|
||||||
|
* **🗑️**: deletes the selected conversation.
|
||||||
|
* **New chat**: starts a new conversation. If you are talking to a character that has a "Greeting" message defined, this message will be automatically added to the new history.
|
||||||
|
|
||||||
Whatever you type there will appear at the start of every reply by the bot. This is useful to guide the response in the desired direction.
|
A search field is also available to filter conversations by name.
|
||||||
|
|
||||||
|
## Sidebar controls
|
||||||
|
|
||||||
|
The sidebar (toggled via "Show controls") contains:
|
||||||
|
|
||||||
|
* **Start reply with**: whatever you type there will appear at the start of every reply by the bot. This is useful to guide the response in the desired direction.
|
||||||
|
* **Reasoning effort**: controls the thinking depth for models that support it. Options: low, medium, high.
|
||||||
|
* **Enable thinking**: enables extended thinking mode for models that support it.
|
||||||
|
* **Activate web search**: when enabled, the model can search the web for information before replying. You can also set the number of pages to download.
|
||||||
|
* **Mode**: see below.
|
||||||
|
* **Chat style**: see below.
|
||||||
|
* **Command for chat-instruct mode**: the command that is used in chat-instruct mode to query the model to generate a reply on behalf of the character. Can be used creatively to generate specific kinds of responses. Inside this string, `<|character|>` is a placeholder that gets replaced with the bot name, and `<|prompt|>` is a placeholder that gets replaced with the full chat prompt.
|
||||||
|
|
||||||
## Mode
|
## Mode
|
||||||
|
|
||||||
|
|
@ -73,7 +86,7 @@ Now that an instruction-following model is defined, we can move on to describing
|
||||||
|
|
||||||
### Chat
|
### Chat
|
||||||
|
|
||||||
Used for talking to the character defined under "Parameters" > "Character" using a simple chat prompt in this format:
|
Used for talking to the character defined under "Character" tab using a simple chat prompt in this format:
|
||||||
|
|
||||||
```
|
```
|
||||||
Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.
|
Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.
|
||||||
|
|
@ -83,7 +96,7 @@ You: How are you?
|
||||||
Chiharu Yamada: I'm doing well, thank you for asking! Is there something specific you would like to talk about or ask me? I'm here to help answer any questions you may have.
|
Chiharu Yamada: I'm doing well, thank you for asking! Is there something specific you would like to talk about or ask me? I'm here to help answer any questions you may have.
|
||||||
```
|
```
|
||||||
|
|
||||||
There are 3 adjustable parameters in "Parameters" > "Character" being used in this prompt:
|
There are 3 adjustable parameters in the "Character" tab being used in this prompt:
|
||||||
|
|
||||||
* The **Context** string appears at the top of the prompt. Most often it describes the bot's personality and adds a few example messages to guide the model towards the desired reply length and format. This string never gets truncated: as the prompt size increases, old messages get removed one at a time until the prompt becomes smaller than the truncation length set under "Parameters" > "Generation" > "Truncate the prompt up to this length".
|
* The **Context** string appears at the top of the prompt. Most often it describes the bot's personality and adds a few example messages to guide the model towards the desired reply length and format. This string never gets truncated: as the prompt size increases, old messages get removed one at a time until the prompt becomes smaller than the truncation length set under "Parameters" > "Generation" > "Truncate the prompt up to this length".
|
||||||
* The **Your name** string appears at the beginning of each user reply. By default, this string is "You".
|
* The **Your name** string appears at the beginning of each user reply. By default, this string is "You".
|
||||||
|
|
@ -99,7 +112,7 @@ Used for talking to an instruction-following model using the prompt format defin
|
||||||
|
|
||||||
The prompt format is defined by the **Instruction template** parameter in "Parameters" > "Instruction template", which represents a Jinja2 template.
|
The prompt format is defined by the **Instruction template** parameter in "Parameters" > "Instruction template", which represents a Jinja2 template.
|
||||||
|
|
||||||
Note that when you load a model in the "Model" tab, the web UI will try to automatically detect its instruction template (if any), and will update the values under "Parameters" > "Instruction template" accordingly. This is done using a set of regular expressions defined in `models/config.yaml`. This detection is not guaranteed to be accurate. You should check the model card on Hugging Face to see if you are using the correct prompt format.
|
Note that when you load a model in the "Model" tab, the web UI will try to automatically detect its instruction template (if any), and will update the values under "Parameters" > "Instruction template" accordingly. This is done using a set of regular expressions defined in `user_data/models/config.yaml`. This detection is not guaranteed to be accurate. You should check the model card on Hugging Face to see if you are using the correct prompt format.
|
||||||
|
|
||||||
### Chat-instruct
|
### Chat-instruct
|
||||||
|
|
||||||
|
|
@ -127,8 +140,6 @@ Here, the command is
|
||||||
|
|
||||||
Below this command, the regular chat prompt is added, including its Context string and the chat history, and then the user turn ends. The bot turn starts with the "Character's name" string followed by `:`, thus prompting the instruction-following model to write a single reply for the character.
|
Below this command, the regular chat prompt is added, including its Context string and the chat history, and then the user turn ends. The bot turn starts with the "Character's name" string followed by `:`, thus prompting the instruction-following model to write a single reply for the character.
|
||||||
|
|
||||||
The chat-instruct command can be customized under "Parameters" > "Instruction template" > "Command for chat-instruct mode". Inside that command string, `<|character|>` is a placeholder that gets replaced with the bot name, and `<|prompt|>` is a placeholder that gets replaced with the full chat prompt.
|
|
||||||
|
|
||||||
Note that you can get creative: instead of writing something trivial like "Write a single reply for the character", you could add more complex instructions like
|
Note that you can get creative: instead of writing something trivial like "Write a single reply for the character", you could add more complex instructions like
|
||||||
|
|
||||||
> This is an adventure game, and your task is to write a reply in name of "<|character|>" where 3 options are given for the user to then choose from.
|
> This is an adventure game, and your task is to write a reply in name of "<|character|>" where 3 options are given for the user to then choose from.
|
||||||
|
|
@ -145,4 +156,4 @@ The styles are only applied to chat and chat-instruct modes. Instruct mode has i
|
||||||
|
|
||||||
## Character gallery
|
## Character gallery
|
||||||
|
|
||||||
This menu is a built-in extension defined under `text-generation-webui/extensions/gallery`. It displays a gallery with your characters, and if you click on a character, it will be automatically selected in the menu under "Parameters" > "Character".
|
This menu is a built-in extension defined under `text-generation-webui/extensions/gallery`. It displays a gallery with your characters, and if you click on a character, it will be automatically selected in the Character tab.
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,11 @@ The number on the lower right of the Input box counts the number of tokens in th
|
||||||
|
|
||||||
Below the Input box, the following buttons can be found:
|
Below the Input box, the following buttons can be found:
|
||||||
|
|
||||||
|
* **Continue**: starts a new generation taking as input the text in the "Output" box.
|
||||||
* **Generate**: starts a new generation.
|
* **Generate**: starts a new generation.
|
||||||
* **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model).
|
* **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model).
|
||||||
* **Continue**: starts a new generation taking as input the text in the "Output" box.
|
|
||||||
|
|
||||||
In the **Prompt** menu, you can select from some predefined prompts defined under `text-generation-webui/prompts`. The 💾 button saves your current input as a new prompt, the 🗑️ button deletes the selected prompt, and the 🔄 button refreshes the list. If you come up with an interesting prompt for a certain task, you are welcome to submit it to the repository.
|
In the **Prompt** menu, you can select from saved prompts stored in `user_data/logs/notebook`. The **New** button creates a new prompt, the **Rename** button renames the selected prompt, and the 🗑️ button deletes it. The 🔄 button refreshes the list.
|
||||||
|
|
||||||
### Output
|
### Output
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,9 +43,15 @@ For more information about the parameters, the [transformers documentation](http
|
||||||
* **presence_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition. Previously called "additive_repetition_penalty".
|
* **presence_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition. Previously called "additive_repetition_penalty".
|
||||||
* **frequency_penalty**: Repetition penalty that scales based on how many times the token has appeared in the context. Be careful with this; there's no limit to how much a token can be penalized.
|
* **frequency_penalty**: Repetition penalty that scales based on how many times the token has appeared in the context. Be careful with this; there's no limit to how much a token can be penalized.
|
||||||
* **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
|
* **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
|
||||||
|
* **dry_multiplier**: Set to greater than 0 to enable DRY (Don't Repeat Yourself) sampling. It penalizes tokens that would extend a sequence that already appeared in the context. Recommended value: 0.8.
|
||||||
|
* **dry_allowed_length**: The longest sequence that can be repeated without being penalized by DRY. Shorter values make DRY more aggressive.
|
||||||
|
* **dry_base**: Controls how fast the DRY penalty grows with increasing sequence length.
|
||||||
* **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
|
* **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
|
||||||
* **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens.
|
* **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens.
|
||||||
* **top_a**: Tokens with probability smaller than `(top_a) * (probability of the most likely token)^2` are discarded.
|
* **top_a**: Tokens with probability smaller than `(top_a) * (probability of the most likely token)^2` are discarded.
|
||||||
|
* **top_n_sigma**: Keeps only tokens within N standard deviations of the mean log-probability. Acts as an adaptive cutoff that adjusts to the shape of the distribution. 0 disables it.
|
||||||
|
* **xtc_threshold**: eXclusion from Top Choices (XTC) sampling. If 2 or more tokens have probability above this threshold, the top token may be removed. This encourages the model to use less common word choices and can increase creativity.
|
||||||
|
* **xtc_probability**: The probability that XTC removal will actually happen when the threshold condition is met. Set to 1 for it to always apply, or lower for occasional application.
|
||||||
* **epsilon_cutoff**: In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled.
|
* **epsilon_cutoff**: In units of 1e-4; a reasonable value is 3. This sets a probability floor below which tokens are excluded from being sampled.
|
||||||
* **eta_cutoff**: In units of 1e-4; a reasonable value is 3. The main parameter of the special Eta Sampling technique. See [this paper](https://arxiv.org/pdf/2210.15191.pdf) for a description.
|
* **eta_cutoff**: In units of 1e-4; a reasonable value is 3. The main parameter of the special Eta Sampling technique. See [this paper](https://arxiv.org/pdf/2210.15191.pdf) for a description.
|
||||||
* **guidance_scale**: The main parameter for Classifier-Free Guidance (CFG). [The paper](https://arxiv.org/pdf/2306.17806.pdf) suggests that 1.5 is a good value. It can be used in conjunction with a negative prompt or not.
|
* **guidance_scale**: The main parameter for Classifier-Free Guidance (CFG). [The paper](https://arxiv.org/pdf/2306.17806.pdf) suggests that 1.5 is a good value. It can be used in conjunction with a negative prompt or not.
|
||||||
|
|
@ -55,36 +61,62 @@ For more information about the parameters, the [transformers documentation](http
|
||||||
*Note: Use either mirostat or dynamic_temperature, not both at the same time.*
|
*Note: Use either mirostat or dynamic_temperature, not both at the same time.*
|
||||||
* **mirostat_tau**: Target perplexity for Mirostat sampling. Controls how “surprising” the text is. Higher values = more diverse, lower = more predictable. Preset Arena suggests 8 as a good value.
|
* **mirostat_tau**: Target perplexity for Mirostat sampling. Controls how “surprising” the text is. Higher values = more diverse, lower = more predictable. Preset Arena suggests 8 as a good value.
|
||||||
* **mirostat_eta**: Learning rate for Mirostat’s perplexity adjustment. Higher values = adapts faster but less stable, lower values = slower but more stable. Preset Arena suggests 0.1 as a good value.
|
* **mirostat_eta**: Learning rate for Mirostat’s perplexity adjustment. Higher values = adapts faster but less stable, lower values = slower but more stable. Preset Arena suggests 0.1 as a good value.
|
||||||
|
* **adaptive_target**: Target probability for adaptive-p sampling. This method adjusts the sampling threshold dynamically based on an exponential moving average of recent token probabilities. 0 disables it.
|
||||||
|
* **adaptive_decay**: EMA decay rate for adaptive-p sampling. Controls how quickly the running average adjusts. Default: 0.9.
|
||||||
* **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynatemp_low" (minimum) and "dynatemp_high" (maximum), with an entropy-based scaling. The steepness of the curve is controlled by "dynatemp_exponent".
|
* **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynatemp_low" (minimum) and "dynatemp_high" (maximum), with an entropy-based scaling. The steepness of the curve is controlled by "dynatemp_exponent".
|
||||||
*Note: Use either dynamic_temperature or mirostat, not both at the same time.*
|
*Note: Use either dynamic_temperature or mirostat, not both at the same time.*
|
||||||
* **smoothing_factor**: Activates Quadratic Sampling. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked.
|
* **smoothing_factor**: Activates Quadratic Sampling. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked.
|
||||||
|
* **smoothing_curve**: Adjusts the dropoff curve of Quadratic Sampling. Higher values make the curve steeper. Only takes effect when smoothing_factor is set.
|
||||||
* **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency. Note: this parameter takes precedence over "Sampler priority". That means that `temperature`/`dynamic_temperature`/`quadratic_sampling` will be removed from wherever they are and moved to the end of the stack.
|
* **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency. Note: this parameter takes precedence over "Sampler priority". That means that `temperature`/`dynamic_temperature`/`quadratic_sampling` will be removed from wherever they are and moved to the end of the stack.
|
||||||
* **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
|
* **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
|
||||||
* **Seed**: Set the Pytorch seed to this number. Note that some loaders do not use Pytorch (notably llama.cpp), and others are not deterministic (ExLlamaV2). For these loaders, the seed has no effect.
|
* **Seed**: Set the Pytorch seed to this number. Note that some loaders do not use Pytorch (notably llama.cpp). For these loaders, the seed has no effect.
|
||||||
* **encoder_repetition_penalty**: Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.
|
* **encoder_repetition_penalty**: Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.
|
||||||
* **no_repeat_ngram_size**: If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.
|
* **no_repeat_ngram_size**: If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.
|
||||||
|
|
||||||
To the right (or below if you are on mobile), the following parameters are present:
|
To the right (or below if you are on mobile), the following parameters are present:
|
||||||
|
|
||||||
* **Truncate the prompt up to this length**: Used to prevent the prompt from getting bigger than the model's context length. In the case of the transformers loader, which allocates memory dynamically, this parameter can also be used to set a VRAM ceiling and prevent out-of-memory errors. This parameter is automatically updated with the model's context length (from "n_ctx" or "max_seq_len" for loaders that use these parameters, and from the model metadata directly for loaders that do not) when you load a model.
|
* **Truncate the prompt up to this length**: Used to prevent the prompt from getting bigger than the model's context length. In the case of the transformers loader, which allocates memory dynamically, this parameter can also be used to set a VRAM ceiling and prevent out-of-memory errors. This parameter is automatically updated with the model's context length (from "ctx_size" for loaders that use this parameter, and from the model metadata directly for loaders that do not) when you load a model.
|
||||||
* **Maximum number of tokens/second**: to make text readable in real-time in case the model is generating too fast. Good if you want to flex and tell everyone how good your GPU is.
|
* **Maximum number of tokens/second**: to make text readable in real-time in case the model is generating too fast. Good if you want to flex and tell everyone how good your GPU is.
|
||||||
|
* **Custom system message**: If not empty, will be used instead of the default system message in the instruction template. Useful for customizing the personality of the chatbot. Example: "You are a duck."
|
||||||
* **Custom stopping strings**: The model stops generating as soon as any of the strings set in this field is generated. Note that when generating text in the Chat tab, some default stopping strings are set regardless of this parameter, like "\nYour Name:" and "\nBot name:" for chat mode. That's why this parameter has a "Custom" in its name.
|
* **Custom stopping strings**: The model stops generating as soon as any of the strings set in this field is generated. Note that when generating text in the Chat tab, some default stopping strings are set regardless of this parameter, like "\nYour Name:" and "\nBot name:" for chat mode. That's why this parameter has a "Custom" in its name.
|
||||||
* **Custom token bans**: Allows you to ban the model from generating certain tokens altogether. You need to find the token IDs under "Default" > "Tokens" or "Notebook" > "Tokens", or by looking at the `tokenizer.json` for the model directly.
|
* **Custom token bans**: Allows you to ban the model from generating certain tokens altogether. You need to find the token IDs under "Default" > "Tokens" or "Notebook" > "Tokens", or by looking at the `tokenizer.json` for the model directly.
|
||||||
* **auto_max_new_tokens**: When checked, the max_new_tokens parameter is expanded in the backend to the available context length. The maximum length is given by the "truncation_length" parameter. This is useful for getting long replies in the Chat tab without having to click on "Continue" many times.
|
* **auto_max_new_tokens**: When checked, the max_new_tokens parameter is expanded in the backend to the available context length. The maximum length is given by the "truncation_length" parameter. This is useful for getting long replies in the Chat tab without having to click on "Continue" many times.
|
||||||
* **Ban the eos_token**: One of the possible tokens that a model can generate is the EOS (End of Sequence) token. When it is generated, the generation stops prematurely. When this parameter is checked, that token is banned from being generated, and the generation will always generate "max_new_tokens" tokens.
|
* **Ban the eos_token**: One of the possible tokens that a model can generate is the EOS (End of Sequence) token. When it is generated, the generation stops prematurely. When this parameter is checked, that token is banned from being generated, and the generation will always generate "max_new_tokens" tokens.
|
||||||
* **Add the bos_token to the beginning of prompts**: By default, the tokenizer will add a BOS (Beginning of Sequence) token to your prompt. During training, BOS tokens are used to separate different documents. If unchecked, no BOS token will be added, and the model will interpret your prompt as being in the middle of a document instead of at the start of one. This significantly changes the output and can make it more creative.
|
* **Add the bos_token to the beginning of prompts**: By default, the tokenizer will add a BOS (Beginning of Sequence) token to your prompt. During training, BOS tokens are used to separate different documents. If unchecked, no BOS token will be added, and the model will interpret your prompt as being in the middle of a document instead of at the start of one. This significantly changes the output and can make it more creative.
|
||||||
* **Skip special tokens**: When decoding the generated tokens, skip special tokens from being converted to their text representation. Otherwise, BOS appears as `<s>`, EOS as `</s>`, etc.
|
* **Skip special tokens**: When decoding the generated tokens, skip special tokens from being converted to their text representation. Otherwise, BOS appears as `<s>`, EOS as `</s>`, etc.
|
||||||
|
* **prompt_lookup_num_tokens**: Activates Prompt Lookup Decoding, a form of speculative decoding for the Transformers loader. It guesses future tokens by looking for matching patterns in the prompt itself, which can speed up generation for tasks that involve repeating or paraphrasing parts of the input.
|
||||||
* **Activate text streaming**: When unchecked, the full response is outputted at once, without streaming the words one at a time. I recommend unchecking this parameter on high latency networks like running the webui on Google Colab or using `--share`.
|
* **Activate text streaming**: When unchecked, the full response is outputted at once, without streaming the words one at a time. I recommend unchecking this parameter on high latency networks like running the webui on Google Colab or using `--share`.
|
||||||
|
* **Static KV cache**: Use a static cache for improved performance with the Transformers loader. May not be compatible with all models.
|
||||||
* **Sampler priority**: Allows you to customize the order in which the different samplers are applied. The first sampler on the list gets applied first. With this, custom orders like `top_p -> temperature -> top_k` can be defined.
|
* **Sampler priority**: Allows you to customize the order in which the different samplers are applied. The first sampler on the list gets applied first. With this, custom orders like `top_p -> temperature -> top_k` can be defined.
|
||||||
* **Load grammar from file**: Loads a GBNF grammar from a file under `text-generation-webui/grammars`. The output is written to the "Grammar" box below. You can also save and delete custom grammars using this menu.
|
* **DRY sequence breakers**: Tokens across which DRY sequence matching is not continued. Typically punctuation and special tokens. Only used when DRY is active (dry_multiplier > 0).
|
||||||
|
* **Load grammar from file**: Loads a GBNF grammar from a file under `user_data/grammars`. The output is written to the "Grammar" box below. You can also save and delete custom grammars using this menu.
|
||||||
* **Grammar**: Allows you to constrain the model output to a particular format. For instance, you can make the model generate lists, JSON, specific words, etc. Grammar is extremely powerful and I highly recommend it. The syntax looks a bit daunting at first sight, but it gets very easy once you understand it. See the [GBNF Guide](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) for details.
|
* **Grammar**: Allows you to constrain the model output to a particular format. For instance, you can make the model generate lists, JSON, specific words, etc. Grammar is extremely powerful and I highly recommend it. The syntax looks a bit daunting at first sight, but it gets very easy once you understand it. See the [GBNF Guide](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) for details.
|
||||||
|
|
||||||
## Character
|
### Chat tab controls
|
||||||
|
|
||||||
Parameters that define the character that is used in the Chat tab when "chat" or "chat-instruct" are selected under "Mode".
|
The following parameters appear in the Chat tab sidebar rather than the Parameters tab:
|
||||||
|
|
||||||
* **Character**: A dropdown menu where you can select from saved characters, save a new character (💾 button), and delete the selected character (🗑️).
|
* **reasoning_effort**: Controls the thinking depth for models that support it (used by GPT-OSS). Options: low, medium, high.
|
||||||
* **Your name**: Your name as it appears in the prompt.
|
* **enable_thinking**: Enables extended thinking mode for models that support it (used by Seed-OSS and pre-2507 Qwen3). When enabled, the model can use a thinking step before generating its reply.
|
||||||
|
|
||||||
|
## Instruction template
|
||||||
|
|
||||||
|
This sub-tab within the Parameters tab defines the instruction template used in the Chat tab when "instruct" or "chat-instruct" are selected under "Mode".
|
||||||
|
|
||||||
|
* **Saved instruction templates**: A dropdown menu where you can select a template. Click **Load** to apply it. The 💾 button saves the current template, and the 🗑️ button deletes the selected one.
|
||||||
|
* **Instruction template**: A Jinja2 template that defines the prompt format for the instruction-following conversation.
|
||||||
|
* **Send to notebook**: Send the full instruction template in string format to the Notebook tab.
|
||||||
|
* **Chat template**: A Jinja2 template that defines the prompt format for regular chat conversations with characters.
|
||||||
|
|
||||||
|
## Character tab
|
||||||
|
|
||||||
|
The Character tab is a separate top-level tab that contains the following sub-tabs:
|
||||||
|
|
||||||
|
### Character
|
||||||
|
|
||||||
|
Parameters that define the character used in the Chat tab when "chat" or "chat-instruct" are selected under "Mode".
|
||||||
|
|
||||||
|
* **Character**: A dropdown menu where you can select from saved characters, save a new character (💾 button), and delete the selected character (🗑️). The **Restore character** button resets the character to its last saved state.
|
||||||
* **Character's name**: The bot name as it appears in the prompt.
|
* **Character's name**: The bot name as it appears in the prompt.
|
||||||
* **Context**: A string that is always at the top of the prompt. It never gets truncated. It usually defines the bot's personality and some key elements of the conversation.
|
* **Context**: A string that is always at the top of the prompt. It never gets truncated. It usually defines the bot's personality and some key elements of the conversation.
|
||||||
* **Greeting**: An opening message for the bot. When set, it appears whenever you start a new chat.
|
* **Greeting**: An opening message for the bot. When set, it appears whenever you start a new chat.
|
||||||
|
|
@ -98,31 +130,26 @@ Note: the following replacements take place in the context and greeting fields w
|
||||||
|
|
||||||
So you can use those special placeholders in your character definitions. They are commonly found in TavernAI character cards.
|
So you can use those special placeholders in your character definitions. They are commonly found in TavernAI character cards.
|
||||||
|
|
||||||
## Instruction template
|
### User
|
||||||
|
|
||||||
Defines the instruction template that is used in the Chat tab when "instruct" or "chat-instruct" are selected under "Mode".
|
Allows you to create and manage user profiles.
|
||||||
|
|
||||||
* **Saved instruction templates**: A dropdown menu where you can load a saved template, save a new template (💾 button), and delete the currently selected template (🗑️).
|
* **User**: A dropdown to select, save (💾), or delete (🗑️) user profiles.
|
||||||
* **Custom system message**: A message that defines the personality of the chatbot, replacing its default "System message" string. Example: "You are a duck."
|
* **Name**: Your name as it appears in the prompt.
|
||||||
* **Instruction template**: A Jinja2 template that defines the prompt format for the instruction-following conversation.
|
* **Description**: An optional description of yourself that can be referenced in conversations.
|
||||||
* **Send to default**: Send the full instruction template in string format to the Default tab.
|
|
||||||
* **Send to notebook**: Send the full instruction template in string format to the Notebook tab.
|
|
||||||
* **Send to negative prompt**: Send the full instruction template in string format to the "Negative prompt" field under "Parameters" > "Generation".
|
|
||||||
* **Chat template**: A Jinja2 template that defines the prompt format for regular chat conversations with characters.
|
|
||||||
* **Command for chat-instruct mode**: The command that is used in chat-instruct mode to query the model to generate a reply on behalf of the character. Can be used creatively to generate specific kinds of responses.
|
|
||||||
|
|
||||||
## Chat history
|
### Chat history
|
||||||
|
|
||||||
In this tab, you can download the current chat history in JSON format and upload a previously saved chat history.
|
In this tab, you can download the current chat history in JSON format and upload a previously saved chat history.
|
||||||
|
|
||||||
When a history is uploaded, a new chat is created to hold it. That is, you don't lose your current chat in the Chat tab.
|
When a history is uploaded, a new chat is created to hold it. That is, you don't lose your current chat in the Chat tab.
|
||||||
|
|
||||||
## Upload character
|
### Upload character
|
||||||
|
|
||||||
### YAML or JSON
|
#### YAML or JSON
|
||||||
|
|
||||||
Allows you to upload characters in the YAML format used by the web UI, including optionally a profile picture.
|
Allows you to upload characters in the YAML format used by the web UI, including optionally a profile picture.
|
||||||
|
|
||||||
### TavernAI PNG
|
#### TavernAI PNG
|
||||||
|
|
||||||
Allows you to upload a TavernAI character card. It will be converted to the internal YAML format of the web UI after upload.
|
Allows you to upload a TavernAI character card. It will be converted to the internal YAML format of the web UI after upload.
|
||||||
|
|
|
||||||
|
|
@ -2,112 +2,89 @@ This is where you load models, apply LoRAs to a loaded model, and download new m
|
||||||
|
|
||||||
## Model loaders
|
## Model loaders
|
||||||
|
|
||||||
### Transformers
|
|
||||||
|
|
||||||
Loads: full precision (16-bit or 32-bit) models. The repository usually has a clean name without GGUF, EXL2, GPTQ, or AWQ in its name, and the model files are named `pytorch_model.bin` or `model.safetensors`.
|
|
||||||
|
|
||||||
Example: [https://huggingface.co/lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5).
|
|
||||||
|
|
||||||
Full precision models use a ton of VRAM, so you will usually want to select the "load_in_4bit" and "use_double_quant" options to load the model in 4-bit precision using bitsandbytes.
|
|
||||||
|
|
||||||
This loader can also load GPTQ models and train LoRAs with them. For that, make sure to check the "auto-devices" and "disable_exllama" options before loading the model.
|
|
||||||
|
|
||||||
Options:
|
|
||||||
|
|
||||||
* **gpu-memory**: When set to greater than 0, activates CPU offloading using the accelerate library, where part of the layers go to the CPU. The performance is very bad. Note that accelerate doesn't treat this parameter very literally, so if you want the VRAM usage to be at most 10 GiB, you may need to set this parameter to 9 GiB or 8 GiB. It can be used in conjunction with "load_in_8bit" but not with "load-in-4bit" as far as I'm aware.
|
|
||||||
* **cpu-memory**: Similarly to the parameter above, you can also set a limit on the amount of CPU memory used. Whatever doesn't fit either in the GPU or the CPU will go to a disk cache, so to use this option you should also check the "disk" checkbox.
|
|
||||||
* **compute_dtype**: Used when "load-in-4bit" is checked. I recommend leaving the default value.
|
|
||||||
* **quant_type**: Used when "load-in-4bit" is checked. I recommend leaving the default value.
|
|
||||||
* **alpha_value**: Used to extend the context length of a model with a minor loss in quality. I have measured 1.75 to be optimal for 1.5x context, and 2.5 for 2x context. That is, with alpha = 2.5 you can make a model with 4096 context length go to 8192 context length.
|
|
||||||
* **rope_freq_base**: Originally another way to write "alpha_value", it ended up becoming a necessary parameter for some models like CodeLlama, which was fine-tuned with this set to 1000000 and hence needs to be loaded with it set to 1000000 as well.
|
|
||||||
* **compress_pos_emb**: The first and original context-length extension method, discovered by [kaiokendev](https://kaiokendev.github.io/til). When set to 2, the context length is doubled, 3 and it's tripled, etc. It should only be used for models that have been fine-tuned with this parameter set to different than 1. For models that have not been tuned to have greater context length, alpha_value will lead to a smaller accuracy loss.
|
|
||||||
* **cpu**: Loads the model in CPU mode using Pytorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with transformers is older than llama.cpp and it works, but it's a lot slower. Note: this parameter has a different interpretation in the llama.cpp loader (see below).
|
|
||||||
* **load-in-8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training and not inference, so load-in-8bit is slower than load-in-4bit (but more accurate).
|
|
||||||
* **bf16**: Use bfloat16 precision instead of float16 (the default). Only applies when quantization is not used.
|
|
||||||
* **auto-devices**: When checked, the backend will try to guess a reasonable value for "gpu-memory" to allow you to load a model with CPU offloading. I recommend just setting "gpu-memory" manually instead. This parameter is also needed for loading GPTQ models, in which case it needs to be checked before loading the model.
|
|
||||||
* **disk**: Enable disk offloading for layers that don't fit into the GPU and CPU combined.
|
|
||||||
* **load-in-4bit**: Load the model in 4-bit precision using bitsandbytes.
|
|
||||||
* **trust-remote-code**: Some models use custom Python code to load the model or the tokenizer. For such models, this option needs to be set. It doesn't download any remote content: all it does is execute the .py files that get downloaded with the model. Those files can potentially include malicious code; I have never seen it happen, but it is in principle possible.
|
|
||||||
* **no_use_fast**: Do not use the "fast" version of the tokenizer. Can usually be ignored; only check this if you can't load the tokenizer for your model otherwise.
|
|
||||||
* **use_flash_attention_2**: Set use_flash_attention_2=True while loading the model. Possibly useful for training.
|
|
||||||
* **disable_exllama**: Only applies when you are loading a GPTQ model through the transformers loader. It needs to be checked if you intend to train LoRAs with the model.
|
|
||||||
|
|
||||||
### ExLlamav2_HF
|
|
||||||
|
|
||||||
Loads: GPTQ and EXL2 models. EXL2 models usually have "EXL2" in the model name, while GPTQ models usually have GPTQ in the model name, or alternatively something like "-4bit-128g" in the name.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
* https://huggingface.co/turboderp/Llama2-70B-exl2
|
|
||||||
* https://huggingface.co/TheBloke/Llama-2-13B-chat-GPTQ
|
|
||||||
|
|
||||||
* **gpu-split**: If you have multiple GPUs, the amount of memory to allocate per GPU should be set in this field. Make sure to set a lower value for the first GPU, as that's where the cache is allocated.
|
|
||||||
* **max_seq_len**: The maximum sequence length for the model. In ExLlamaV2, the cache is preallocated, so the higher this value, the higher the VRAM. It is automatically set to the maximum sequence length for the model based on its metadata, but you may need to lower this value be able to fit the model into your GPU. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "max_seq_len" so that you don't have to set the same thing twice.
|
|
||||||
* **cfg-cache**: Creates a second cache to hold the CFG negative prompts. You need to set this if and only if you intend to use CFG in the "Parameters" > "Generation" tab. Checking this parameter doubles the cache VRAM usage.
|
|
||||||
* **no_flash_attn**: Disables flash attention. Otherwise, it is automatically used as long as the library is installed.
|
|
||||||
* **cache_8bit**: Create a 8-bit precision cache instead of a 16-bit one. This saves VRAM but increases perplexity (I don't know by how much).
|
|
||||||
* **cache_4bit**: Creates a Q4 cache using grouped quantization.
|
|
||||||
|
|
||||||
### ExLlamav2
|
|
||||||
|
|
||||||
The same as ExLlamav2_HF but using the internal samplers of ExLlamav2 instead of the ones in the Transformers library.
|
|
||||||
|
|
||||||
### AutoGPTQ
|
|
||||||
|
|
||||||
Loads: GPTQ models.
|
|
||||||
|
|
||||||
* **wbits**: For ancient models without proper metadata, sets the model precision in bits manually. Can usually be ignored.
|
|
||||||
* **groupsize**: For ancient models without proper metadata, sets the model group size manually. Can usually be ignored.
|
|
||||||
* **triton**: Only available on Linux. Necessary to use models with both act-order and groupsize simultaneously. Note that ExLlamaV2 can load these same models on Windows without triton.
|
|
||||||
* **no_inject_fused_attention**: Improves performance while increasing the VRAM usage.
|
|
||||||
* **no_inject_fused_mlp**: Similar to the previous parameter but for Triton only.
|
|
||||||
* **no_use_cuda_fp16**: On some systems, the performance can be very bad with this unset. Can usually be ignored.
|
|
||||||
* **desc_act**: For ancient models without proper metadata, sets the model "act-order" parameter manually. Can usually be ignored.
|
|
||||||
|
|
||||||
### llama.cpp
|
### llama.cpp
|
||||||
|
|
||||||
Loads: GGUF models. Note: GGML models have been deprecated and do not work anymore.
|
Loads: GGUF models. Note: GGML models have been deprecated and do not work anymore.
|
||||||
|
|
||||||
Example: https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF
|
Example: https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF
|
||||||
|
|
||||||
* **n-gpu-layers**: The number of layers to allocate to the GPU. If set to 0, only the CPU will be used. If you want to offload all layers, you can simply set this to the maximum value.
|
* **gpu_layers**: The number of layers to allocate to the GPU. If set to 0, only the CPU will be used. If you want to offload all layers, you can simply set this to the maximum value.
|
||||||
* **n_ctx**: Context length of the model. In llama.cpp, the cache is preallocated, so the higher this value, the higher the VRAM. It is automatically set to the maximum sequence length for the model based on the metadata inside the GGUF file, but you may need to lower this value be able to fit the model into your GPU. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "n_ctx" so that you don't have to set the same thing twice.
|
* **ctx_size**: Context length of the model. In llama.cpp, the cache is preallocated, so the higher this value, the higher the VRAM. It is automatically set to the maximum sequence length for the model based on the metadata inside the GGUF file, but you may need to lower this value to fit the model into your GPU. Set to 0 for automatic context size based on available memory. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "ctx_size" so that you don't have to set the same thing twice.
|
||||||
|
* **cache_type**: KV cache quantization type. Valid options: `fp16`, `q8_0`, `q4_0`. Lower quantization saves VRAM at the cost of some quality.
|
||||||
* **tensor_split**: For multi-gpu only. Sets the amount of memory to allocate per GPU as proportions. Not to be confused with other loaders where this is set in GB; here you can set something like `30,70` for 30%/70%.
|
* **tensor_split**: For multi-gpu only. Sets the amount of memory to allocate per GPU as proportions. Not to be confused with other loaders where this is set in GB; here you can set something like `30,70` for 30%/70%.
|
||||||
* **n_batch**: Batch size for prompt processing. Higher values are supposed to make generation faster, but I have never obtained any benefit from changing this value.
|
* **batch_size**: Maximum number of prompt tokens to batch together when calling llama_eval.
|
||||||
|
* **ubatch_size**: Physical maximum batch size for prompt processing.
|
||||||
* **threads**: Number of threads. Recommended value: your number of physical cores.
|
* **threads**: Number of threads. Recommended value: your number of physical cores.
|
||||||
* **threads_batch**: Number of threads for batch processing. Recommended value: your total number of cores (physical + virtual).
|
* **threads_batch**: Number of threads for batch processing. Recommended value: your total number of cores (physical + virtual).
|
||||||
* **tensorcores**: Use llama.cpp compiled with "tensor cores" support, which improves performance on NVIDIA RTX cards in most cases.
|
* **cpu_moe**: Force MoE expert layers to run on the CPU, keeping the rest on the GPU.
|
||||||
* **streamingllm**: Experimental feature to avoid re-evaluating the entire prompt when part of it is removed, for instance, when you hit the context length for the model in chat mode and an old message is removed.
|
* **extra_flags**: Extra flags to pass to llama-server. Format: `flag1=value1,flag2,flag3=value3`. Example: `override-tensor=exps=CPU`.
|
||||||
|
* **mmproj**: Path to the mmproj file for multimodal (vision) models. This enables image understanding capabilities.
|
||||||
|
* **streaming_llm**: Experimental feature to avoid re-evaluating the entire prompt when part of it is removed, for instance, when you hit the context length for the model in chat mode and an old message is removed.
|
||||||
* **cpu**: Force a version of llama.cpp compiled without GPU acceleration to be used. Can usually be ignored. Only set this if you want to use CPU only and llama.cpp doesn't work otherwise.
|
* **cpu**: Force a version of llama.cpp compiled without GPU acceleration to be used. Can usually be ignored. Only set this if you want to use CPU only and llama.cpp doesn't work otherwise.
|
||||||
* **no_mul_mat_q**: Disable the mul_mat_q kernel. This kernel usually improves generation speed significantly. This option to disable it is included in case it doesn't work on some system.
|
* **row_split**: Split the model by rows across GPUs. This may improve multi-gpu performance.
|
||||||
* **no-mmap**: Loads the model into memory at once, possibly preventing I/O operations later on at the cost of a longer load time.
|
* **no_kv_offload**: Do not offload the KV cache to the GPU. This saves VRAM but reduces performance.
|
||||||
* **mlock**: Force the system to keep the model in RAM rather than swapping or compressing (no idea what this means, never used it).
|
* **no_mmap**: Loads the model into memory at once, possibly preventing I/O operations later on at the cost of a longer load time.
|
||||||
|
* **mlock**: Force the system to keep the model in RAM rather than swapping or compressing.
|
||||||
* **numa**: May improve performance on certain multi-cpu systems.
|
* **numa**: May improve performance on certain multi-cpu systems.
|
||||||
|
|
||||||
### llamacpp_HF
|
### Transformers
|
||||||
|
|
||||||
The same as llama.cpp but with transformers samplers, and using the transformers tokenizer instead of the internal llama.cpp tokenizer.
|
Loads: full precision (16-bit or 32-bit) models, as well as bitsandbytes-quantized models. The repository usually has a clean name without GGUF or EXL3 in its name, and the model files are named `model.safetensors` or split into parts like `model-00001-of-00004.safetensors`.
|
||||||
|
|
||||||
To use it, you need to download a tokenizer. There are two options:
|
Example: [https://huggingface.co/lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5).
|
||||||
|
|
||||||
1) Download `oobabooga/llama-tokenizer` under "Download model or LoRA". That's a default Llama tokenizer.
|
Full precision models use a ton of VRAM, so you will usually want to select the "load_in_4bit" and "use_double_quant" options to load the model in 4-bit precision using bitsandbytes.
|
||||||
2) Place your .gguf in a subfolder of `models/` along with these 3 files: `tokenizer.model`, `tokenizer_config.json`, and `special_tokens_map.json`. This takes precedence over Option 1.
|
|
||||||
|
|
||||||
It has an additional parameter:
|
Options:
|
||||||
|
|
||||||
* **logits_all**: Needs to be checked if you want to evaluate the perplexity of the llama.cpp model using the "Training" > "Perplexity evaluation" tab. Otherwise, leave it unchecked, as it makes prompt processing slower.
|
* **gpu_split**: When using multiple GPUs, sets the amount of VRAM in GB to allocate per GPU. Example: `20,7,7`.
|
||||||
|
* **cpu_memory**: Maximum CPU memory in GiB to use for CPU offloading via the accelerate library. Whatever doesn't fit in the GPU or CPU will go to a disk cache if the "disk" checkbox is enabled.
|
||||||
|
* **compute_dtype**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
|
||||||
|
* **quant_type**: Used when "load_in_4bit" is checked. I recommend leaving the default value.
|
||||||
|
* **attn_implementation**: Choose the attention implementation. Valid options: `sdpa`, `eager`, `flash_attention_2`. The default (`sdpa`) works well in most cases; `flash_attention_2` may be useful for training.
|
||||||
|
* **cpu**: Loads the model in CPU mode using Pytorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with transformers is older than llama.cpp and it works, but it's a lot slower. Note: this parameter has a different interpretation in the llama.cpp loader (see above).
|
||||||
|
* **load_in_8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training and not inference, so load_in_8bit is slower than load_in_4bit (but more accurate).
|
||||||
|
* **bf16**: Use bfloat16 precision instead of float16 (the default). Only applies when quantization is not used.
|
||||||
|
* **disk**: Enable disk offloading for layers that don't fit into the GPU and CPU combined.
|
||||||
|
* **load_in_4bit**: Load the model in 4-bit precision using bitsandbytes.
|
||||||
|
* **use_double_quant**: Use double quantization with 4-bit loading for reduced memory usage.
|
||||||
|
* **trust-remote-code**: Some models use custom Python code to load the model or the tokenizer. For such models, this option needs to be set. It doesn't download any remote content: all it does is execute the .py files that get downloaded with the model. Those files can potentially include malicious code; I have never seen it happen, but it is in principle possible.
|
||||||
|
* **no_use_fast**: Do not use the "fast" version of the tokenizer. Can usually be ignored; only check this if you can't load the tokenizer for your model otherwise.
|
||||||
|
|
||||||
### AutoAWQ
|
### ExLlamav3_HF
|
||||||
|
|
||||||
Loads: AWQ models.
|
Loads: EXL3 models. These models usually have "EXL3" or "exl3" in the model name.
|
||||||
|
|
||||||
Example: https://huggingface.co/TheBloke/Phind-CodeLlama-34B-v2-AWQ
|
Uses the ExLlamaV3 backend with Transformers samplers.
|
||||||
|
|
||||||
The parameters are overall similar to AutoGPTQ.
|
* **ctx_size**: Context length of the model. The cache is preallocated, so the higher this value, the higher the VRAM. It is automatically set to the maximum sequence length for the model based on its metadata, but you may need to lower this value to fit the model into your GPU. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "ctx_size" so that you don't have to set the same thing twice.
|
||||||
|
* **cache_type**: KV cache quantization type. Valid options: `fp16`, `q2` to `q8`. You can also specify key and value bits separately, e.g. `q4_q8`. Lower quantization saves VRAM at the cost of some quality.
|
||||||
|
* **gpu_split**: Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: `20,7,7`.
|
||||||
|
* **cfg_cache**: Creates a second cache to hold the CFG negative prompts. You need to set this if and only if you intend to use CFG in the "Parameters" > "Generation" tab. Checking this parameter doubles the cache VRAM usage.
|
||||||
|
* **no_use_fast**: Do not use the "fast" version of the tokenizer.
|
||||||
|
* **enable_tp**: Enable Tensor Parallelism (TP) to split the model across GPUs.
|
||||||
|
* **tp_backend**: The backend for tensor parallelism. Valid options: `native`, `nccl`. Default: `native`.
|
||||||
|
|
||||||
|
### ExLlamav3
|
||||||
|
|
||||||
|
The same as ExLlamav3_HF but using the internal samplers of ExLlamaV3 instead of the ones in the Transformers library. Supports speculative decoding with a draft model. Also supports multimodal (vision) models natively.
|
||||||
|
|
||||||
|
* **ctx_size**: Same as ExLlamav3_HF.
|
||||||
|
* **cache_type**: Same as ExLlamav3_HF.
|
||||||
|
* **gpu_split**: Same as ExLlamav3_HF.
|
||||||
|
* **enable_tp**: Enable Tensor Parallelism (TP) to split the model across GPUs.
|
||||||
|
* **tp_backend**: The backend for tensor parallelism. Valid options: `native`, `nccl`. Default: `native`.
|
||||||
|
|
||||||
|
### TensorRT-LLM
|
||||||
|
|
||||||
|
Loads: TensorRT-LLM engine models. These are highly optimized models compiled specifically for NVIDIA GPUs.
|
||||||
|
|
||||||
|
* **ctx_size**: Context length of the model.
|
||||||
|
* **cpp_runner**: Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn't support streaming yet.
|
||||||
|
|
||||||
## Model dropdown
|
## Model dropdown
|
||||||
|
|
||||||
Here you can select a model to be loaded, refresh the list of available models (🔄), load/unload/reload the selected model, and save the settings for the model. The "settings" are the values in the input fields (checkboxes, sliders, dropdowns) below this dropdown.
|
Here you can select a model to be loaded, refresh the list of available models, load/unload/reload the selected model, and save the settings for the model. The "settings" are the values in the input fields (checkboxes, sliders, dropdowns) below this dropdown.
|
||||||
|
|
||||||
After saving, those settings will get restored whenever you select that model again in the dropdown menu.
|
After saving, those settings will get restored whenever you select that model again in the dropdown menu.
|
||||||
|
|
||||||
|
|
@ -115,14 +92,14 @@ If the **Autoload the model** checkbox is selected, the model will be loaded as
|
||||||
|
|
||||||
## LoRA dropdown
|
## LoRA dropdown
|
||||||
|
|
||||||
Used to apply LoRAs to the model. Note that LoRA support is not implemented for all loaders. Check this [page](https://github.com/oobabooga/text-generation-webui/wiki) for details.
|
Used to apply LoRAs to the model. Note that LoRA support is not implemented for all loaders. Check the [What Works](https://github.com/oobabooga/text-generation-webui/wiki/What-Works) page for details.
|
||||||
|
|
||||||
## Download model or LoRA
|
## Download model or LoRA
|
||||||
|
|
||||||
Here you can download a model or LoRA directly from the https://huggingface.co/ website.
|
Here you can download a model or LoRA directly from the https://huggingface.co/ website.
|
||||||
|
|
||||||
* Models will be saved to `text-generation-webui/models`.
|
* Models will be saved to `user_data/models`.
|
||||||
* LoRAs will be saved to `text-generation-webui/loras`.
|
* LoRAs will be saved to `user_data/loras`.
|
||||||
|
|
||||||
In the input field, you can enter either the Hugging Face username/model path (like `facebook/galactica-125m`) or the full model URL (like `https://huggingface.co/facebook/galactica-125m`). To specify a branch, add it at the end after a ":" character like this: `facebook/galactica-125m:main`.
|
In the input field, you can enter either the Hugging Face username/model path (like `facebook/galactica-125m`) or the full model URL (like `https://huggingface.co/facebook/galactica-125m`). To specify a branch, add it at the end after a ":" character like this: `facebook/galactica-125m:main`.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,139 +1,121 @@
|
||||||
## Training Your Own LoRAs
|
## Training Your Own LoRAs
|
||||||
|
|
||||||
The WebUI seeks to make training your own LoRAs as easy as possible. It comes down to just a few simple steps:
|
A LoRA is tied to a specific model architecture — a LoRA trained on Llama 3 8B won't work on Mistral 7B. Train on the exact model you plan to use.
|
||||||
|
|
||||||
### **Step 1**: Make a plan.
|
### Quick Start
|
||||||
- What base model do you want to use? The LoRA you make has to be matched up to a single architecture (eg LLaMA-13B) and cannot be transferred to others (eg LLaMA-7B, StableLM, etc. would all be different). Derivatives of the same model (eg Alpaca finetune of LLaMA-13B) might be transferrable, but even then it's best to train exactly on what you plan to use.
|
|
||||||
- What are you training it on? Do you want it to learn real information, a simple format, ...?
|
|
||||||
|
|
||||||
### **Step 2**: Gather a dataset.
|
1. Load your base model with the **Transformers** loader (no LoRAs loaded).
|
||||||
- If you use a dataset similar to the [Alpaca](https://github.com/gururise/AlpacaDataCleaned/blob/main/alpaca_data_cleaned.json) format, that is natively supported by the `Formatted Dataset` input in the WebUI, with premade formatter options.
|
2. Open the **Training** tab > **Train LoRA**.
|
||||||
- If you use a dataset that isn't matched to Alpaca's format, but uses the same basic JSON structure, you can make your own format file by copying `training/formats/alpaca-format.json` to a new file and [editing its content](#format-files).
|
3. Pick a dataset and configure parameters (see [below](#parameters)).
|
||||||
- If you can get the dataset into a simple text file, that works too! You can train using the `Raw text file` input option.
|
4. Click **Start LoRA Training** and monitor the [loss](#loss).
|
||||||
- This means you can for example just copy/paste a chatlog/documentation page/whatever you want, shove it in a plain text file, and train on it.
|
5. When done, load the LoRA from the **Models** tab and test it.
|
||||||
- If you use a structured dataset not in this format, you may have to find an external way to convert it - or open an issue to request native support.
|
|
||||||
|
|
||||||
### **Step 3**: Do the training.
|
### Resuming Training
|
||||||
- **3.1**: Load the WebUI, and your model.
|
|
||||||
- Make sure you don't have any LoRAs already loaded (unless you want to train for multi-LoRA usage).
|
|
||||||
- **3.2**: Open the `Training` tab at the top, `Train LoRA` sub-tab.
|
|
||||||
- **3.3**: Fill in the name of the LoRA, select your dataset in the dataset options.
|
|
||||||
- **3.4**: Select other parameters to your preference. See [parameters below](#parameters).
|
|
||||||
- **3.5**: click `Start LoRA Training`, and wait.
|
|
||||||
- It can take a few hours for a large dataset, or just a few minute if doing a small run.
|
|
||||||
- You may want to monitor your [loss value](#loss) while it goes.
|
|
||||||
|
|
||||||
### **Step 4**: Evaluate your results.
|
To resume from a checkpoint, use the same LoRA name and uncheck `Override Existing Files`. If checkpoints exist (from `Save every n steps`), training will automatically resume from the latest one with full optimizer and scheduler state preserved. Note that you cannot change the `Rank` of an already created LoRA.
|
||||||
- Load the LoRA under the Models Tab.
|
|
||||||
- You can go test-drive it on the `Text generation` tab, or you can use the `Perplexity evaluation` sub-tab of the `Training` tab.
|
|
||||||
- If you used the `Save every n steps` option, you can grab prior copies of the model from sub-folders within the LoRA model's folder and try them instead.
|
|
||||||
|
|
||||||
### **Step 5**: Re-run if you're unhappy.
|
You should also use `Copy parameters from` to restore the UI settings (learning rate, epochs, etc.) from the previous run, so that training continues with the same configuration.
|
||||||
- Make sure to unload the LoRA before training it.
|
|
||||||
- You can simply resume a prior run - use `Copy parameters from` to select your LoRA, and edit parameters. Note that you cannot change the `Rank` of an already created LoRA.
|
|
||||||
- If you want to resume from a checkpoint saved along the way, simply copy the contents of the checkpoint folder into the LoRA's folder.
|
|
||||||
- (Note: `adapter_model.bin` is the important file that holds the actual LoRA content).
|
|
||||||
- This will start Learning Rate and Steps back to the start. If you want to resume as if you were midway through, you can adjust your Learning Rate to the last reported LR in logs and reduce your epochs.
|
|
||||||
- Or, you can start over entirely if you prefer.
|
|
||||||
- If your model is producing corrupted outputs, you probably need to start over and use a lower Learning Rate.
|
|
||||||
- If your model isn't learning detailed information but you want it to, you might need to just run more epochs, or you might need a higher Rank.
|
|
||||||
- If your model is enforcing a format you didn't want, you may need to tweak your dataset, or start over and not train as far.
|
|
||||||
|
|
||||||
## Format Files
|
### Troubleshooting
|
||||||
|
|
||||||
If using JSON formatted datasets, they are presumed to be in the following approximate format:
|
- **Corrupted outputs**: Start over with a lower Learning Rate.
|
||||||
|
- **Not learning enough**: Run more epochs, or increase the Rank.
|
||||||
|
- **Unwanted formatting**: Tweak your dataset, or train for fewer steps.
|
||||||
|
|
||||||
|
## Instruction Templates
|
||||||
|
|
||||||
|
All instruction/chat training uses `apply_chat_template()` with Jinja2 templates. You have two options in the **Instruction Template** dropdown:
|
||||||
|
|
||||||
|
- **Chat Template**: Uses the model's built-in chat template from its tokenizer. Works with instruct/chat models that ship with a chat template (Llama 3, Qwen, Mistral, etc.).
|
||||||
|
- **Named template** (e.g. ChatML, Alpaca, Llama-v3, etc.): Loads a Jinja2 template from `user_data/instruction-templates/`. This is useful for base models that don't have a built-in template, or when you want to override the model's default template.
|
||||||
|
|
||||||
|
Both options are functionally identical — the only difference is where the Jinja2 template string comes from. In both cases:
|
||||||
|
- The dataset is tokenized via `apply_chat_template()`
|
||||||
|
- Labels are automatically masked so only assistant responses are trained on
|
||||||
|
- Multi-turn conversations are supported natively
|
||||||
|
- Special tokens are handled correctly by the template
|
||||||
|
|
||||||
|
The WebUI ships with 50+ templates in `user_data/instruction-templates/`. You can also add your own by creating a `.yaml` file with an `instruction_template` key containing a Jinja2 template string, or a plain `.jinja` file.
|
||||||
|
|
||||||
|
**Dataset formats:** Your JSON dataset can use either of these structures:
|
||||||
|
|
||||||
|
OpenAI messages format:
|
||||||
```json
|
```json
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"somekey": "somevalue",
|
"messages": [
|
||||||
"key2": "value2"
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
},
|
{"role": "user", "content": "What is Python?"},
|
||||||
{
|
{"role": "assistant", "content": "A programming language."},
|
||||||
// etc
|
{"role": "user", "content": "What's it used for?"},
|
||||||
|
{"role": "assistant", "content": "Web dev, data science, scripting, and more."}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
Where the keys (eg `somekey`, `key2` above) are standardized, and relatively consistent across the dataset, and the values (eg `somevalue`, `value2`) contain the content actually intended to be trained.
|
ShareGPT format (`conversations` key with `from`/`value` fields):
|
||||||
|
|
||||||
For Alpaca, the keys are `instruction`, `input`, and `output`, wherein `input` is sometimes blank.
|
|
||||||
|
|
||||||
A simple format file for Alpaca to be used as a chat bot is:
|
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
[
|
||||||
"instruction,output": "User: %instruction%\nAssistant: %output%",
|
{
|
||||||
"instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
|
"conversations": [
|
||||||
}
|
{"from": "system", "value": "You are a helpful assistant."},
|
||||||
|
{"from": "human", "value": "What is Python?"},
|
||||||
|
{"from": "gpt", "value": "A programming language."},
|
||||||
|
{"from": "human", "value": "What's it used for?"},
|
||||||
|
{"from": "gpt", "value": "Web dev, data science, scripting, and more."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that the keys (eg `instruction,output`) are a comma-separated list of dataset keys, and the values are a simple string that use those keys with `%%`.
|
## Text Dataset
|
||||||
|
|
||||||
So for example if a dataset has `"instruction": "answer my question"`, then the format file's `User: %instruction%\n` will be automatically filled in as `User: answer my question\n`.
|
For pretraining-style training on raw text, use the **Text Dataset** tab. Your dataset should be a JSON file with one document per row, each with a `"text"` key:
|
||||||
|
|
||||||
If you have different sets of key inputs, you can make your own format file to match it. This format-file is designed to be as simple as possible to enable easy editing to match your needs.
|
```json
|
||||||
|
[
|
||||||
|
{"text": "First document content..."},
|
||||||
|
{"text": "Second document content..."}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
## Raw Text File Settings
|
This is the standard format used by most pretraining datasets (The Pile, RedPajama, etc.).
|
||||||
|
|
||||||
When using raw text files as your dataset, the text is automatically split into chunks based on your `Cutoff Length` you get a few basic options to configure them.
|
Each document is tokenized (with BOS token), concatenated into one long token sequence, and split into chunks of `Cutoff Length` tokens. The final chunk is padded if shorter than the cutoff length. When `Add EOS token` is enabled, an EOS token is appended after each document before concatenation, helping the model learn document boundaries.
|
||||||
- `Overlap Length` is how much to overlap chunks by. Overlapping chunks helps prevent the model from learning strange mid-sentence cuts, and instead learn continual sentences that flow from earlier text.
|
|
||||||
- `Prefer Newline Cut Length` sets a maximum distance in characters to shift the chunk cut towards newlines. Doing this helps prevent lines from starting or ending mid-sentence, preventing the model from learning to cut off sentences randomly.
|
- `Stride Length` controls the overlap between consecutive chunks in tokens. Set to 0 for non-overlapping chunks (the standard concatenate-and-split approach). Values like 256 or 512 create overlapping chunks that help the model learn context across chunk boundaries, at the cost of more training samples.
|
||||||
- `Hard Cut String` sets a string that indicates there must be a hard cut without overlap. This defaults to `\n\n\n`, meaning 3 newlines. No trained chunk will ever contain this string. This allows you to insert unrelated sections of text in the same text file, but still ensure the model won't be taught to randomly change the subject.
|
|
||||||
|
## Target Modules
|
||||||
|
|
||||||
|
By default, **Target all linear layers** is enabled. This uses peft's `all-linear` mode, which applies LoRA to every `nn.Linear` layer in the model except the output head (`lm_head`). It works for any model architecture.
|
||||||
|
|
||||||
|
If you uncheck it, you can manually select individual projection modules (`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `down_proj`, `up_proj`). Targeting fewer modules reduces VRAM usage and adapter size, but also reduces how much the model can learn. The default selection of `q_proj` + `v_proj` is the minimum for basic style/format training.
|
||||||
|
|
||||||
## Parameters
|
## Parameters
|
||||||
|
|
||||||
The basic purpose and function of each parameter is documented on-page in the WebUI, so read through them in the UI to understand your options.
|
Each parameter has a description in the UI. Below is guidance on the most important choices.
|
||||||
|
|
||||||
That said, here's a guide to the most important parameter choices you should consider:
|
|
||||||
|
|
||||||
### VRAM
|
### VRAM
|
||||||
|
|
||||||
- First, you must consider your VRAM availability.
|
VRAM usage during training is roughly similar to inference with ~1000 tokens of context. If you can run the model, you can probably train LoRAs with the default settings. If you run out of VRAM, reduce `Micro Batch Size` or `Cutoff Length`. Training 4-bit quantized models uses more VRAM — set `Micro Batch Size` to `1` to compensate.
|
||||||
- Generally, under default settings, VRAM usage for training with default parameters is very close to when generating text (with 1000+ tokens of context) (ie, if you can generate text, you can train LoRAs).
|
|
||||||
- Note: worse by default in the 4-bit monkeypatch currently. Reduce `Micro Batch Size` to `1` to restore this to expectations.
|
|
||||||
- If you have VRAM to spare, setting higher batch sizes will use more VRAM and get you better quality training in exchange.
|
|
||||||
- If you have large data, setting a higher cutoff length may be beneficial, but will cost significant VRAM. If you can spare some, set your batch size to `1` and see how high you can push your cutoff length.
|
|
||||||
- If you're low on VRAM, reducing batch size or cutoff length will of course improve that.
|
|
||||||
- Don't be afraid to just try it and see what happens. If it's too much, it will just error out, and you can lower settings and try again.
|
|
||||||
|
|
||||||
### Rank
|
### Rank
|
||||||
|
|
||||||
- Second, you want to consider the amount of learning you want.
|
Higher rank = more learning capacity = larger adapter = more VRAM. Use 4–8 for style/format, 128–256 to teach factual knowledge.
|
||||||
- For example, you may wish to just learn a dialogue format (as in the case of Alpaca) in which case setting a low `Rank` value (32 or lower) works great.
|
|
||||||
- Or, you might be training on project documentation you want the bot to understand and be able to understand questions about, in which case the higher the rank, the better.
|
|
||||||
- Generally, higher Rank = more precise learning = more total content learned = more VRAM usage while training.
|
|
||||||
|
|
||||||
### Learning Rate and Epochs
|
### Learning Rate and Epochs
|
||||||
|
|
||||||
- Third, how carefully you want it to be learned.
|
These control how aggressively the model learns and how many times it sees the data. Higher LR + fewer epochs = fast but rough. Lower LR + more epochs = slower but higher quality. The scheduler (default: cosine) decays the LR over the course of training — see [HuggingFace docs](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules#schedules) for graphs of each option.
|
||||||
- In other words, how okay or not you are with the model losing unrelated understandings.
|
|
||||||
- You can control this with 3 key settings: the Learning Rate, its scheduler, and your total epochs.
|
|
||||||
- The learning rate controls how much change is made to the model by each token it sees.
|
|
||||||
- It's in scientific notation normally, so for example `3e-4` means `3 * 10^-4` which is `0.0003`. The number after `e-` controls how many `0`s are in the number.
|
|
||||||
- Higher values let training run faster, but also are more likely to corrupt prior data in the model.
|
|
||||||
- You essentially have two variables to balance: the LR, and Epochs.
|
|
||||||
- If you make LR higher, you can set Epochs equally lower to match. High LR + low epochs = very fast, low quality training.
|
|
||||||
- If you make LR low, set epochs high. Low LR + high epochs = slow but high-quality training.
|
|
||||||
- The scheduler controls change-over-time as you train - it starts high, and then goes low. This helps balance getting data in, and having decent quality, at the same time.
|
|
||||||
- You can see graphs of the different scheduler options [in the HuggingFace docs here](https://moon-ci-docs.huggingface.co/docs/transformers/pr_1/en/main_classes/optimizer_schedules#transformers.SchedulerType)
|
|
||||||
|
|
||||||
## Loss
|
## Loss
|
||||||
|
|
||||||
When you're running training, the WebUI's console window will log reports that include, among other things, a numeric value named `Loss`. It will start as a high number, and gradually get lower and lower as it goes.
|
When you're running training, the WebUI's console window will log reports that include, among other things, a numeric value named `Loss`. It will start as a high number, and gradually get lower and lower as it goes.
|
||||||
|
|
||||||
"Loss" in the world of AI training theoretically means "how close is the model to perfect", with `0` meaning "absolutely perfect". This is calculated by measuring the difference between the model outputting exactly the text you're training it to output, and what it actually outputs.
|
Loss measures how far the model's predictions are from the training data, with `0` meaning a perfect match. It's calculated as the cross-entropy between the model's output distribution and the expected tokens.
|
||||||
|
|
||||||
In practice, a good LLM should have a very complex variable range of ideas running in its artificial head, so a loss of `0` would indicate that the model has broken and forgotten how to think about anything other than what you trained it on.
|
In practice, a loss of `0` means the model has overfit — it memorized the training data at the expense of its general capabilities.
|
||||||
|
|
||||||
So, in effect, Loss is a balancing game: you want to get it low enough that it understands your data, but high enough that it isn't forgetting everything else. Generally, if it goes below `1.0`, it's going to start forgetting its prior memories, and you should stop training. In some cases you may prefer to take it as low as `0.5` (if you want it to be very very predictable). Different goals have different needs, so don't be afraid to experiment and see what works best for you.
|
Loss is a balancing game: you want it low enough that the model learns your data, but not so low that it loses general knowledge. Generally, if it goes below `1.0`, overfitting is likely and you should stop training. In some cases you may want to go as low as `0.5` (if you need very predictable outputs). Different goals have different needs, so experiment and see what works best for you.
|
||||||
|
|
||||||
Note: if you see Loss start at or suddenly jump to exactly `0`, it is likely something has gone wrong in your training process (eg model corruption).
|
Note: if you see Loss start at or suddenly jump to exactly `0`, it is likely something has gone wrong in your training process (eg model corruption).
|
||||||
|
|
||||||
## Note: 4-Bit Monkeypatch
|
|
||||||
|
|
||||||
The [4-bit LoRA monkeypatch](GPTQ-models-(4-bit-mode).md#using-loras-in-4-bit-mode) works for training, but has side effects:
|
|
||||||
- VRAM usage is higher currently. You can reduce the `Micro Batch Size` to `1` to compensate.
|
|
||||||
- Models do funky things. LoRAs apply themselves, or refuse to apply, or spontaneously error out, or etc. It can be helpful to reload base model or restart the WebUI between training/usage to minimize chances of anything going haywire.
|
|
||||||
- Loading or working with multiple LoRAs at the same time doesn't currently work.
|
|
||||||
- Generally, recognize and treat the monkeypatch as the dirty temporary hack it is - it works, but isn't very stable. It will get better in time when everything is merged upstream for full official support.
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,15 @@
|
||||||
Here you can restart the UI with new settings.
|
Here you can restart the UI with new settings.
|
||||||
|
|
||||||
* **Available extensions**: shows a list of extensions available under `text-generation-webui/extensions`.
|
## Settings
|
||||||
|
|
||||||
|
* **Toggle light/dark theme**: switches between light and dark mode.
|
||||||
|
* **Show two columns in the Notebook tab**: toggles between the two-column Default layout and the single-column Notebook layout.
|
||||||
|
* **Turn long pasted text into attachments in the Chat tab**: when enabled, long pasted text is automatically converted into file attachments.
|
||||||
|
* **Include attachments/search results from previous messages in the chat prompt**: when enabled, attachments and web search results from earlier messages are included in subsequent prompts.
|
||||||
|
|
||||||
|
## Extensions & flags
|
||||||
|
|
||||||
|
* **Available extensions**: shows a list of extensions available under `text-generation-webui/extensions` and `text-generation-webui/user_data/extensions`. Note that some of these extensions may require manually installing Python requirements through the command: `pip install -r extensions/extension_name/requirements.txt`.
|
||||||
* **Boolean command-line flags**: shows command-line flags of bool (true/false) type.
|
* **Boolean command-line flags**: shows command-line flags of bool (true/false) type.
|
||||||
|
|
||||||
After selecting your desired flags and extensions, you can restart the UI by clicking on **Apply flags/extensions and restart**.
|
After selecting your desired flags and extensions, you can restart the UI by clicking on **Apply flags/extensions and restart**.
|
||||||
|
|
@ -27,6 +36,6 @@ If you used the one-click installer, this command should be executed in the term
|
||||||
|
|
||||||
## Saving UI defaults
|
## Saving UI defaults
|
||||||
|
|
||||||
The **Save UI defaults to settings.yaml** button gathers the visible values in the UI and saves them to settings.yaml so that your settings will persist across multiple restarts of the UI.
|
The **Save extensions settings to user_data/settings.yaml** button gathers the visible values in the UI and saves them to `user_data/settings.yaml` so that your settings will persist across multiple restarts of the UI.
|
||||||
|
|
||||||
Note that preset parameters like temperature are not individually saved, so you need to first save your preset and select it in the preset menu before saving the defaults.
|
Note that preset parameters like temperature are not individually saved, so you need to first save your preset and select it in the preset menu before saving the defaults.
|
||||||
|
|
|
||||||
|
|
@ -21,17 +21,19 @@ If you create an extension, you are welcome to host it in a GitHub repository an
|
||||||
|Extension|Description|
|
|Extension|Description|
|
||||||
|---------|-----------|
|
|---------|-----------|
|
||||||
|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. |
|
|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. |
|
||||||
|[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For a detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. |
|
|[superboogav2](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superboogav2)| Enhanced RAG extension with support for PDF, DOCX, and PPTX files. |
|
||||||
|[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.|
|
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|
||||||
|
|[coqui_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/coqui_tts)| Text-to-speech extension using Coqui XTTS v2. |
|
||||||
|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, responses are replaced with an audio widget. |
|
|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, responses are replaced with an audio widget. |
|
||||||
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. |
|
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. |
|
||||||
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). |
|
|
||||||
|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that adds a hidden string at the beginning of the bot's reply in chat mode. |
|
|
||||||
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|
|
||||||
|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. |
|
|
||||||
|[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. |
|
|
||||||
|[ngrok](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/ngrok)| Allows you to access the web UI remotely using the ngrok reverse tunnel service (free). It's an alternative to the built-in Gradio `--share` feature. |
|
|
||||||
|[perplexity_colors](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/perplexity_colors)| Colors each token in the output text by its associated probability, as derived from the model logits. |
|
|[perplexity_colors](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/perplexity_colors)| Colors each token in the output text by its associated probability, as derived from the model logits. |
|
||||||
|
|[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.|
|
||||||
|
|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. |
|
||||||
|
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). |
|
||||||
|
|[long_replies](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/long_replies)| Forces longer replies by suppressing early newlines in the model output. |
|
||||||
|
|[ngrok](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/ngrok)| Allows you to access the web UI remotely using the ngrok reverse tunnel service (free). It's an alternative to the built-in Gradio `--share` feature. |
|
||||||
|
|[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. |
|
||||||
|
|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that adds a hidden string at the beginning of the bot's reply in chat mode. |
|
||||||
|
|
||||||
## How to write an extension
|
## How to write an extension
|
||||||
|
|
||||||
|
|
@ -51,8 +53,8 @@ The extensions framework is based on special functions and variables that you ca
|
||||||
| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. |
|
| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. |
|
||||||
| `def custom_generate_reply(...)` | Overrides the main text generation function. |
|
| `def custom_generate_reply(...)` | Overrides the main text generation function. |
|
||||||
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
|
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
|
||||||
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. |
|
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `example` extension for a template. |
|
||||||
| `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `multimodal` extension for an example. |
|
| `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `example` extension for a template. |
|
||||||
|
|
||||||
Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI, and the `is_tab` key is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab.
|
Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI, and the `is_tab` key is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab.
|
||||||
|
|
||||||
|
|
@ -186,7 +188,7 @@ def bot_prefix_modifier(string, state):
|
||||||
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
||||||
"""
|
"""
|
||||||
Modifies the input ids and embeds.
|
Modifies the input ids and embeds.
|
||||||
Used by the multimodal extension to put image embeddings in the prompt.
|
Modifies the input ids and embeds fed to the model.
|
||||||
Only used by loaders that use the transformers library for sampling.
|
Only used by loaders that use the transformers library for sampling.
|
||||||
"""
|
"""
|
||||||
return prompt, input_ids, input_embeds
|
return prompt, input_ids, input_embeds
|
||||||
|
|
|
||||||
|
|
@ -13,29 +13,6 @@ Source: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/1126
|
||||||
|
|
||||||
This file will be automatically detected the next time you start the web UI.
|
This file will be automatically detected the next time you start the web UI.
|
||||||
|
|
||||||
## DeepSpeed
|
|
||||||
|
|
||||||
`DeepSpeed ZeRO-3` is an alternative offloading strategy for full-precision (16-bit) transformers models.
|
|
||||||
|
|
||||||
With this, I have been able to load a 6b model (GPT-J 6B) with less than 6GB of VRAM. The speed of text generation is very decent and much better than what would be accomplished with `--auto-devices --gpu-memory 6`.
|
|
||||||
|
|
||||||
As far as I know, DeepSpeed is only available for Linux at the moment.
|
|
||||||
|
|
||||||
### How to use it
|
|
||||||
|
|
||||||
1. Install DeepSpeed:
|
|
||||||
|
|
||||||
```
|
|
||||||
conda install -c conda-forge mpi4py mpich
|
|
||||||
pip install -U deepspeed
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Start the web UI replacing `python` with `deepspeed --num_gpus=1` and adding the `--deepspeed` flag. Example:
|
|
||||||
|
|
||||||
```
|
|
||||||
deepspeed --num_gpus=1 server.py --deepspeed --chat --model gpt-j-6B
|
|
||||||
```
|
|
||||||
|
|
||||||
## Miscellaneous info
|
## Miscellaneous info
|
||||||
|
|
||||||
### You can train LoRAs in CPU mode
|
### You can train LoRAs in CPU mode
|
||||||
|
|
|
||||||
|
|
@ -1,208 +1,52 @@
|
||||||
Docker Compose is a way of installing and launching the web UI in an isolated Ubuntu image using only a few commands.
|
Docker Compose is a way of installing and launching the web UI in an isolated Ubuntu image using only a few commands.
|
||||||
|
|
||||||
## Installing Docker Compose
|
## Prerequisites
|
||||||
|
|
||||||
In order to create the image as described in the main README, you must have Docker Compose installed (2.17 or higher is recommended):
|
You need Docker Compose v2.17 or higher:
|
||||||
|
|
||||||
```
|
```
|
||||||
~$ docker compose version
|
~$ docker compose version
|
||||||
Docker Compose version v2.21.0
|
Docker Compose version v2.21.0
|
||||||
```
|
```
|
||||||
|
|
||||||
The installation instructions for various Linux distributions can be found here:
|
Installation instructions: https://docs.docker.com/engine/install/
|
||||||
|
|
||||||
https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository
|
For NVIDIA GPUs, you also need the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
|
||||||
|
|
||||||
## Launching the image
|
## Quick start
|
||||||
|
|
||||||
Use these commands to launch the image:
|
There are four Docker variants available under `docker/`:
|
||||||
|
|
||||||
```
|
| Directory | GPU | Notes |
|
||||||
cd text-generation-webui
|
|-----------|-----|-------|
|
||||||
ln -s docker/{nvidia/Dockerfile,nvidia/docker-compose.yml,.dockerignore} .
|
| `docker/nvidia` | NVIDIA | Requires NVIDIA Container Toolkit |
|
||||||
cp docker/.env.example .env
|
| `docker/amd` | AMD | Requires ROCm-compatible GPU |
|
||||||
# Edit .env and set TORCH_CUDA_ARCH_LIST based on your GPU model
|
| `docker/intel` | Intel Arc | Beta support |
|
||||||
|
| `docker/cpu` | None | CPU-only inference |
|
||||||
|
|
||||||
|
To launch (using NVIDIA as an example):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd text-generation-webui/docker/nvidia
|
||||||
|
cp ../.env.example .env
|
||||||
|
# Optionally edit .env to customize ports, TORCH_CUDA_ARCH_LIST, etc.
|
||||||
docker compose up --build
|
docker compose up --build
|
||||||
```
|
```
|
||||||
|
|
||||||
## More detailed installation instructions
|
The web UI will be available at `http://localhost:7860`.
|
||||||
|
|
||||||
* [Docker Compose installation instructions](#docker-compose-installation-instructions)
|
## User data
|
||||||
* [Repository with additional Docker files](#dedicated-docker-repository)
|
|
||||||
|
|
||||||
By [@loeken](https://github.com/loeken).
|
Create a `user_data/` directory next to the `docker-compose.yml` to persist your models, characters, presets, and settings between container rebuilds:
|
||||||
|
|
||||||
- [Ubuntu 22.04](#ubuntu-2204)
|
|
||||||
- [0. youtube video](#0-youtube-video)
|
|
||||||
- [1. update the drivers](#1-update-the-drivers)
|
|
||||||
- [2. reboot](#2-reboot)
|
|
||||||
- [3. install docker](#3-install-docker)
|
|
||||||
- [4. docker \& container toolkit](#4-docker--container-toolkit)
|
|
||||||
- [5. clone the repo](#5-clone-the-repo)
|
|
||||||
- [6. prepare models](#6-prepare-models)
|
|
||||||
- [7. prepare .env file](#7-prepare-env-file)
|
|
||||||
- [8. startup docker container](#8-startup-docker-container)
|
|
||||||
- [Manjaro](#manjaro)
|
|
||||||
- [update the drivers](#update-the-drivers)
|
|
||||||
- [reboot](#reboot)
|
|
||||||
- [docker \& container toolkit](#docker--container-toolkit)
|
|
||||||
- [continue with ubuntu task](#continue-with-ubuntu-task)
|
|
||||||
- [Windows](#windows)
|
|
||||||
- [0. youtube video](#0-youtube-video-1)
|
|
||||||
- [1. choco package manager](#1-choco-package-manager)
|
|
||||||
- [2. install drivers/dependencies](#2-install-driversdependencies)
|
|
||||||
- [3. install wsl](#3-install-wsl)
|
|
||||||
- [4. reboot](#4-reboot)
|
|
||||||
- [5. git clone \&\& startup](#5-git-clone--startup)
|
|
||||||
- [6. prepare models](#6-prepare-models-1)
|
|
||||||
- [7. startup](#7-startup)
|
|
||||||
- [notes](#notes)
|
|
||||||
|
|
||||||
### Ubuntu 22.04
|
|
||||||
|
|
||||||
#### 0. youtube video
|
|
||||||
A video walking you through the setup can be found here:
|
|
||||||
|
|
||||||
[](https://www.youtube.com/watch?v=ELkKWYh8qOk)
|
|
||||||
|
|
||||||
|
|
||||||
#### 1. update the drivers
|
|
||||||
in the the “software updater” update drivers to the last version of the prop driver.
|
|
||||||
|
|
||||||
#### 2. reboot
|
|
||||||
to switch using to new driver
|
|
||||||
|
|
||||||
#### 3. install docker
|
|
||||||
```bash
|
```bash
|
||||||
sudo apt update
|
mkdir -p user_data
|
||||||
sudo apt-get install curl
|
|
||||||
sudo mkdir -m 0755 -p /etc/apt/keyrings
|
|
||||||
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
|
|
||||||
echo \
|
|
||||||
"deb [arch="$(dpkg --print-architecture)" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
|
|
||||||
"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | \
|
|
||||||
sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
|
|
||||||
sudo apt update
|
|
||||||
sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin docker-compose -y
|
|
||||||
sudo usermod -aG docker $USER
|
|
||||||
newgrp docker
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 4. docker & container toolkit
|
This directory is mounted into the container at runtime. You can place a `CMD_FLAGS.txt` inside it to pass persistent flags to the web UI (e.g., `--api`).
|
||||||
```bash
|
|
||||||
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
|
|
||||||
echo "deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://nvidia.github.io/libnvidia-container/stable/ubuntu22.04/amd64 /" | \
|
|
||||||
sudo tee /etc/apt/sources.list.d/nvidia.list > /dev/null
|
|
||||||
sudo apt update
|
|
||||||
sudo apt install nvidia-docker2 nvidia-container-runtime -y
|
|
||||||
sudo systemctl restart docker
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 5. clone the repo
|
Models can be downloaded through the web UI's “Model” tab once it's running, and they will be saved to `user_data/models/`.
|
||||||
```
|
|
||||||
git clone https://github.com/oobabooga/text-generation-webui
|
|
||||||
cd text-generation-webui
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 6. prepare models
|
|
||||||
download and place the models inside the models folder. tested with:
|
|
||||||
|
|
||||||
4bit
|
|
||||||
https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483891617
|
|
||||||
https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483941105
|
|
||||||
|
|
||||||
8bit:
|
|
||||||
https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1484235789
|
|
||||||
|
|
||||||
#### 7. prepare .env file
|
|
||||||
edit .env values to your needs.
|
|
||||||
```bash
|
|
||||||
cp .env.example .env
|
|
||||||
nano .env
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 8. startup docker container
|
|
||||||
```bash
|
|
||||||
docker compose up --build
|
|
||||||
```
|
|
||||||
|
|
||||||
### Manjaro
|
|
||||||
manjaro/arch is similar to ubuntu just the dependency installation is more convenient
|
|
||||||
|
|
||||||
#### update the drivers
|
|
||||||
```bash
|
|
||||||
sudo mhwd -a pci nonfree 0300
|
|
||||||
```
|
|
||||||
#### reboot
|
|
||||||
```bash
|
|
||||||
reboot
|
|
||||||
```
|
|
||||||
#### docker & container toolkit
|
|
||||||
```bash
|
|
||||||
yay -S docker docker-compose buildkit gcc nvidia-docker
|
|
||||||
sudo usermod -aG docker $USER
|
|
||||||
newgrp docker
|
|
||||||
sudo systemctl restart docker # required by nvidia-container-runtime
|
|
||||||
```
|
|
||||||
|
|
||||||
#### continue with ubuntu task
|
|
||||||
continue at [5. clone the repo](#5-clone-the-repo)
|
|
||||||
|
|
||||||
### Windows
|
|
||||||
#### 0. youtube video
|
|
||||||
A video walking you through the setup can be found here:
|
|
||||||
[](https://www.youtube.com/watch?v=ejH4w5b5kFQ)
|
|
||||||
|
|
||||||
#### 1. choco package manager
|
|
||||||
install package manager (https://chocolatey.org/ )
|
|
||||||
```
|
|
||||||
Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 2. install drivers/dependencies
|
|
||||||
```
|
|
||||||
choco install nvidia-display-driver cuda git docker-desktop
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 3. install wsl
|
|
||||||
wsl --install
|
|
||||||
|
|
||||||
#### 4. reboot
|
|
||||||
after reboot enter username/password in wsl
|
|
||||||
|
|
||||||
#### 5. git clone && startup
|
|
||||||
clone the repo and edit .env values to your needs.
|
|
||||||
```
|
|
||||||
cd Desktop
|
|
||||||
git clone https://github.com/oobabooga/text-generation-webui
|
|
||||||
cd text-generation-webui
|
|
||||||
COPY .env.example .env
|
|
||||||
notepad .env
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 6. prepare models
|
|
||||||
download and place the models inside the models folder. tested with:
|
|
||||||
|
|
||||||
4bit https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483891617 https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1483941105
|
|
||||||
|
|
||||||
8bit: https://github.com/oobabooga/text-generation-webui/pull/530#issuecomment-1484235789
|
|
||||||
|
|
||||||
#### 7. startup
|
|
||||||
```
|
|
||||||
docker compose up
|
|
||||||
```
|
|
||||||
|
|
||||||
### notes
|
|
||||||
|
|
||||||
on older ubuntus you can manually install the docker compose plugin like this:
|
|
||||||
```
|
|
||||||
DOCKER_CONFIG=${DOCKER_CONFIG:-$HOME/.docker}
|
|
||||||
mkdir -p $DOCKER_CONFIG/cli-plugins
|
|
||||||
curl -SL https://github.com/docker/compose/releases/download/v2.17.2/docker-compose-linux-x86_64 -o $DOCKER_CONFIG/cli-plugins/docker-compose
|
|
||||||
chmod +x $DOCKER_CONFIG/cli-plugins/docker-compose
|
|
||||||
export PATH="$HOME/.docker/cli-plugins:$PATH"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dedicated docker repository
|
## Dedicated docker repository
|
||||||
|
|
||||||
An external repository maintains a docker wrapper for this project as well as several pre-configured 'one-click' `docker compose` variants (e.g., updated branches of GPTQ). It can be found at: [Atinoda/text-generation-webui-docker](https://github.com/Atinoda/text-generation-webui-docker).
|
An external repository maintains a docker wrapper for this project as well as several pre-configured 'one-click' `docker compose` variants. It can be found at: [Atinoda/text-generation-webui-docker](https://github.com/Atinoda/text-generation-webui-docker).
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,25 @@
|
||||||
## Using an AMD GPU in Linux
|
## Using an AMD GPU in Linux
|
||||||
|
|
||||||
Requires ROCm SDK 5.4.2 or 5.4.3 to be installed. Some systems may also
|
Requires ROCm 6.4 to be installed.
|
||||||
need:
|
|
||||||
|
### Option 1: One-click installer
|
||||||
|
|
||||||
|
The one-click installer (`start_linux.sh`) automatically detects AMD GPUs. When prompted, select the AMD option, or set the `GPU_CHOICE` environment variable before running:
|
||||||
|
|
||||||
```
|
```
|
||||||
sudo apt-get install libstdc++-12-dev
|
GPU_CHOICE=B ./start_linux.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
Edit the "one_click.py" script using a text editor and un-comment and
|
### Option 2: Manual conda install
|
||||||
modify the lines near the top of the script according to your setup. In
|
|
||||||
particular, modify the `os.environ["ROCM_PATH"] = '/opt/rocm'` line to
|
Follow the manual conda installation instructions in the README, using the AMD PyTorch command:
|
||||||
point to your ROCm installation.
|
|
||||||
|
```
|
||||||
|
pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/rocm6.4
|
||||||
|
```
|
||||||
|
|
||||||
|
Then install the project requirements with the AMD requirements file:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -r requirements/full/requirements_amd.txt
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ curl http://127.0.0.1:5000/v1/completions \
|
||||||
|
|
||||||
#### Chat completions
|
#### Chat completions
|
||||||
|
|
||||||
Works best with instruction-following models. If the "instruction_template" variable is not provided, it will be guessed automatically based on the model name using the regex patterns in `models/config.yaml`.
|
Works best with instruction-following models. If the "instruction_template" variable is not provided, it will be guessed automatically based on the model name using the regex patterns in `user_data/models/config.yaml`.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://127.0.0.1:5000/v1/chat/completions \
|
curl http://127.0.0.1:5000/v1/chat/completions \
|
||||||
|
|
@ -338,6 +338,35 @@ for event in client.events():
|
||||||
print()
|
print()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Python parallel requests example
|
||||||
|
|
||||||
|
The API supports handling multiple requests in parallel. For ExLlamaV3, this works out of the box. For llama.cpp, you need to pass `--parallel N` to set the number of concurrent slots.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import concurrent.futures
|
||||||
|
import requests
|
||||||
|
|
||||||
|
url = "http://127.0.0.1:5000/v1/chat/completions"
|
||||||
|
prompts = [
|
||||||
|
"Write a haiku about the ocean.",
|
||||||
|
"Explain quantum computing in simple terms.",
|
||||||
|
"Tell me a joke about programmers.",
|
||||||
|
]
|
||||||
|
|
||||||
|
def send_request(prompt):
|
||||||
|
response = requests.post(url, json={
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"max_tokens": 200,
|
||||||
|
})
|
||||||
|
return response.json()["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
results = list(executor.map(send_request, prompts))
|
||||||
|
|
||||||
|
for prompt, result in zip(prompts, results):
|
||||||
|
print(f"Q: {prompt}\nA: {result}\n")
|
||||||
|
```
|
||||||
|
|
||||||
#### Python example with API key
|
#### Python example with API key
|
||||||
|
|
||||||
Replace
|
Replace
|
||||||
|
|
@ -359,83 +388,93 @@ headers = {
|
||||||
|
|
||||||
in any of the examples above.
|
in any of the examples above.
|
||||||
|
|
||||||
#### Tool/Function Calling Example
|
#### Tool/Function calling
|
||||||
|
|
||||||
You need to use a model with tools support. The prompt will be automatically formatted using the model's Jinja2 template.
|
Use a model with tool calling support (Qwen, Mistral, GPT-OSS, etc). Tools are passed via the `tools` parameter and the prompt is automatically formatted using the model's Jinja2 template.
|
||||||
|
|
||||||
Request:
|
When the model decides to call a tool, the response will have `finish_reason: "tool_calls"` and a `tool_calls` array with structured function names and arguments. You then execute the tool, send the result back as a `role: "tool"` message, and continue until the model responds with `finish_reason: "stop"`.
|
||||||
|
|
||||||
```
|
Some models call multiple tools in parallel (Qwen, Mistral), while others call one at a time (GPT-OSS). The loop below handles both styles.
|
||||||
curl http://127.0.0.1:5000/v1/chat/completions \
|
|
||||||
-H "Content-Type: application/json" \
|
```python
|
||||||
-d '{
|
import json
|
||||||
"messages": [
|
import requests
|
||||||
{
|
|
||||||
"role": "system",
|
url = "http://127.0.0.1:5000/v1/chat/completions"
|
||||||
"content": "You are a helpful assistant."
|
|
||||||
},
|
# Define your tools
|
||||||
{
|
tools = [
|
||||||
"role": "user",
|
|
||||||
"content": "What time is it currently in New York City?"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"tools": [
|
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "get_current_time",
|
"name": "get_weather",
|
||||||
"description": "Get current time in a specific timezones",
|
"description": "Get the current weather for a given location",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["timezone"],
|
|
||||||
"properties": {
|
"properties": {
|
||||||
"timezone": {
|
"location": {"type": "string", "description": "City name"},
|
||||||
"type": "string",
|
},
|
||||||
"description": "IANA timezone name (e.g., America/New_York, Europe/London). Use Europe/Berlin as local timezone if no timezone provided by the user."
|
"required": ["location"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
Sample response:
|
|
||||||
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"id": "chatcmpl-1746532051477984256",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1746532051,
|
|
||||||
"model": "qwen2.5-coder-14b-instruct-q4_k_m.gguf",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": "tool_calls",
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "```xml\n<function>\n{\n \"name\": \"get_current_time\",\n \"arguments\": {\n \"timezone\": \"America/New_York\"\n }\n}\n</function>\n```"
|
|
||||||
},
|
},
|
||||||
"tool_calls": [
|
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "get_current_time",
|
"name": "get_time",
|
||||||
"arguments": "{\"timezone\": \"America/New_York\"}"
|
"description": "Get the current time in a given timezone",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"timezone": {"type": "string", "description": "IANA timezone string"},
|
||||||
},
|
},
|
||||||
"id": "call_52ij07mh",
|
"required": ["timezone"]
|
||||||
"index": "0"
|
|
||||||
}
|
}
|
||||||
]
|
|
||||||
}
|
}
|
||||||
],
|
},
|
||||||
"usage": {
|
]
|
||||||
"prompt_tokens": 224,
|
|
||||||
"completion_tokens": 38,
|
|
||||||
"total_tokens": 262
|
def execute_tool(name, arguments):
|
||||||
}
|
"""Replace this with your actual tool implementations."""
|
||||||
}
|
if name == "get_weather":
|
||||||
|
return {"temperature": 22, "condition": "sunny", "humidity": 45}
|
||||||
|
elif name == "get_time":
|
||||||
|
return {"time": "2:30 PM", "timezone": "JST"}
|
||||||
|
return {"error": f"Unknown tool: {name}"}
|
||||||
|
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "What time is it in Tokyo and what's the weather like there?"}]
|
||||||
|
|
||||||
|
# Tool-calling loop: keep going until the model gives a final answer
|
||||||
|
for _ in range(10):
|
||||||
|
response = requests.post(url, json={"messages": messages, "tools": tools}).json()
|
||||||
|
choice = response["choices"][0]
|
||||||
|
|
||||||
|
if choice["finish_reason"] == "tool_calls":
|
||||||
|
# Add the assistant's response (with tool_calls) to history
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": choice["message"]["content"],
|
||||||
|
"tool_calls": choice["message"]["tool_calls"],
|
||||||
|
})
|
||||||
|
|
||||||
|
# Execute each tool and add results to history
|
||||||
|
for tool_call in choice["message"]["tool_calls"]:
|
||||||
|
name = tool_call["function"]["name"]
|
||||||
|
arguments = json.loads(tool_call["function"]["arguments"])
|
||||||
|
result = execute_tool(name, arguments)
|
||||||
|
|
||||||
|
print(f"Tool call: {name}({arguments}) => {result}")
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call["id"],
|
||||||
|
"content": json.dumps(result),
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# Final answer
|
||||||
|
print(f"\nAssistant: {choice['message']['content']}")
|
||||||
|
break
|
||||||
```
|
```
|
||||||
|
|
||||||
### Environment variables
|
### Environment variables
|
||||||
|
|
@ -476,51 +515,45 @@ OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||||
OPENAI_API_BASE=http://127.0.0.1:5000/v1
|
OPENAI_API_BASE=http://127.0.0.1:5000/v1
|
||||||
```
|
```
|
||||||
|
|
||||||
With the [official python openai client](https://github.com/openai/openai-python), the address can be set like this:
|
With the [official python openai client](https://github.com/openai/openai-python) (v1.x), the address can be set like this:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import openai
|
from openai import OpenAI
|
||||||
|
|
||||||
openai.api_key = "..."
|
client = OpenAI(
|
||||||
openai.api_base = "http://127.0.0.1:5000/v1"
|
api_key="sk-111111111111111111111111111111111111111111111111",
|
||||||
openai.api_version = "2023-05-15"
|
base_url="http://127.0.0.1:5000/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="x",
|
||||||
|
messages=[{"role": "user", "content": "Hello!"}]
|
||||||
|
)
|
||||||
|
print(response.choices[0].message.content)
|
||||||
```
|
```
|
||||||
|
|
||||||
If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported:
|
With the [official Node.js openai client](https://github.com/openai/openai-node) (v4.x):
|
||||||
|
|
||||||
```python
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
load_dotenv() # make sure the environment variables are set before import
|
|
||||||
import openai
|
|
||||||
```
|
|
||||||
|
|
||||||
With the [official Node.js openai client](https://github.com/openai/openai-node) it is slightly more more complex because the environment variables are not used by default, so small source code changes may be required to use the environment variables, like so:
|
|
||||||
|
|
||||||
```js
|
```js
|
||||||
const openai = OpenAI(
|
import OpenAI from "openai";
|
||||||
Configuration({
|
|
||||||
apiKey: process.env.OPENAI_API_KEY,
|
|
||||||
basePath: process.env.OPENAI_API_BASE
|
|
||||||
})
|
|
||||||
);
|
|
||||||
```
|
|
||||||
|
|
||||||
For apps made with the [chatgpt-api Node.js client library](https://github.com/transitive-bullshit/chatgpt-api):
|
const client = new OpenAI({
|
||||||
|
|
||||||
```js
|
|
||||||
const api = new ChatGPTAPI({
|
|
||||||
apiKey: process.env.OPENAI_API_KEY,
|
apiKey: process.env.OPENAI_API_KEY,
|
||||||
apiBaseUrl: process.env.OPENAI_API_BASE
|
baseURL: "http://127.0.0.1:5000/v1",
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const response = await client.chat.completions.create({
|
||||||
|
model: "x",
|
||||||
|
messages: [{ role: "user", content: "Hello!" }],
|
||||||
|
});
|
||||||
|
console.log(response.choices[0].message.content);
|
||||||
```
|
```
|
||||||
### Embeddings (alpha)
|
### Embeddings (alpha)
|
||||||
|
|
||||||
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
|
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings. The model is small and fast. This model and embedding size may change in the future.
|
||||||
|
|
||||||
| model name | dimensions | input max tokens | speed | size | Avg. performance |
|
| model name | dimensions | input max tokens | speed | size | Avg. performance |
|
||||||
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
|
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
|
||||||
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
|
|
||||||
| text-davinci-002 | 768 | 2046 | - | - | - |
|
|
||||||
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
|
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
|
||||||
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
|
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
|
||||||
|
|
||||||
|
|
@ -528,50 +561,33 @@ In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller st
|
||||||
|
|
||||||
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
|
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
|
||||||
|
|
||||||
### Compatibility & not so compatibility
|
### Compatibility
|
||||||
|
|
||||||
Note: the table below may be obsolete.
|
| API endpoint | notes |
|
||||||
|
| ------------------------- | --------------------------------------------------------------------------- |
|
||||||
| API endpoint | tested with | notes |
|
| /v1/chat/completions | Use with instruction-following models. Supports streaming, tool calls. |
|
||||||
| ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- |
|
| /v1/completions | Text completion endpoint. |
|
||||||
| /v1/chat/completions | openai.ChatCompletion.create() | Use it with instruction following models |
|
| /v1/embeddings | Using SentenceTransformer embeddings. |
|
||||||
| /v1/embeddings | openai.Embedding.create() | Using SentenceTransformer embeddings |
|
| /v1/images/generations | Image generation, response_format='b64_json' only. |
|
||||||
| /v1/images/generations | openai.Image.create() | Bare bones, no model configuration, response_format='b64_json' only. |
|
| /v1/moderations | Basic support via embeddings. |
|
||||||
| /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings |
|
| /v1/models | Lists models. Currently loaded model first. |
|
||||||
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
|
| /v1/models/{id} | Returns model info. |
|
||||||
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
|
| /v1/audio/\* | Supported. |
|
||||||
| /v1/edits | openai.Edit.create() | Removed, use /v1/chat/completions instead |
|
| /v1/images/edits | Not yet supported. |
|
||||||
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
|
| /v1/images/variations | Not yet supported. |
|
||||||
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
|
|
||||||
| /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint |
|
|
||||||
| /v1/engines/\*/generate | openai engines.generate | Legacy endpoint |
|
|
||||||
| /v1/engines | openai engines.list | Legacy Lists models |
|
|
||||||
| /v1/engines/{model_name} | openai engines.get -i {model_name} | You can use this legacy endpoint to load models via the api or command line |
|
|
||||||
| /v1/images/edits | openai.Image.create_edit() | not yet supported |
|
|
||||||
| /v1/images/variations | openai.Image.create_variation() | not yet supported |
|
|
||||||
| /v1/audio/\* | openai.Audio.\* | supported |
|
|
||||||
| /v1/files\* | openai.Files.\* | not yet supported |
|
|
||||||
| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
|
|
||||||
| /v1/search | openai.search, engines.search | not yet supported |
|
|
||||||
|
|
||||||
#### Applications
|
#### Applications
|
||||||
|
|
||||||
Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variable set, but there are some exceptions.
|
Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variables set, but there are some exceptions.
|
||||||
|
|
||||||
Note: the table below may be obsolete.
|
|
||||||
|
|
||||||
| Compatibility | Application/Library | Website | Notes |
|
| Compatibility | Application/Library | Website | Notes |
|
||||||
| ------------- | ---------------------- | ------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
| ------------- | -------------------- | ------------------------------------------------------------------------------ | ----------------------------------------------------------------------------------------- |
|
||||||
| ✅❌ | openai-python (v0.25+) | https://github.com/openai/openai-python | only the endpoints from above are working. OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅❌ | openai-python | https://github.com/openai/openai-python | Use `OpenAI(base_url="http://127.0.0.1:5000/v1")`. Only the endpoints from above work. |
|
||||||
| ✅❌ | openai-node | https://github.com/openai/openai-node | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
|
| ✅❌ | openai-node | https://github.com/openai/openai-node | Use `new OpenAI({baseURL: "http://127.0.0.1:5000/v1"})`. See example above. |
|
||||||
| ✅❌ | chatgpt-api | https://github.com/transitive-bullshit/chatgpt-api | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
|
| ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI, Images also work. |
|
||||||
| ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI, Images also work |
|
| ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5000 |
|
||||||
| ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5001 |
|
| ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5000/v1 |
|
||||||
| ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5000/v1 |
|
||||||
| ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅ | OpenAI for Notepad++ | https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5000 in the config file, or environment variables. |
|
||||||
| ✅ | OpenAI for Notepad++ | https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5001 in the config file, or environment variables |
|
| ✅ | vscode-openai | https://marketplace.visualstudio.com/items?itemName=AndrewButson.vscode-openai | OPENAI_API_BASE=http://127.0.0.1:5000/v1 |
|
||||||
| ✅ | vscode-openai | https://marketplace.visualstudio.com/items?itemName=AndrewButson.vscode-openai | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅❌ | langchain | https://github.com/hwchase17/langchain | Use `base_url="http://127.0.0.1:5000/v1"`. Results depend on model and prompt formatting. |
|
||||||
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
|
|
||||||
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
|
||||||
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
|
||||||
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
|
|
||||||
|
|
|
||||||
159
docs/Tool Calling Tutorial.md
Normal file
159
docs/Tool Calling Tutorial.md
Normal file
|
|
@ -0,0 +1,159 @@
|
||||||
|
## Supported models
|
||||||
|
|
||||||
|
The following models are supported:
|
||||||
|
|
||||||
|
- Qwen 3.5
|
||||||
|
- GPT-OSS
|
||||||
|
- Mistral Small / Devstral
|
||||||
|
- DeepSeek V3
|
||||||
|
- Kimi-K2
|
||||||
|
- MiniMax-M2.5
|
||||||
|
- GLM-5
|
||||||
|
- Llama 4
|
||||||
|
|
||||||
|
Other models that output tool calls as JSON (inside XML tags, code blocks, or plain JSON) are also supported through a generic fallback parser.
|
||||||
|
|
||||||
|
## Tool calling in the UI
|
||||||
|
|
||||||
|
### 1. Load a model with tool-calling support
|
||||||
|
|
||||||
|
Load a model with tool-calling support from the Model tab.
|
||||||
|
|
||||||
|
### 2. Select tools
|
||||||
|
|
||||||
|
In the chat sidebar, check the tools you want the model to use:
|
||||||
|
|
||||||
|
- **web_search** -- Search the web using DuckDuckGo.
|
||||||
|
- **fetch_webpage** -- Fetch the content of a URL.
|
||||||
|
- **calculate** -- Evaluate math expressions.
|
||||||
|
- **get_datetime** -- Get the current date and time.
|
||||||
|
- **roll_dice** -- Roll dice.
|
||||||
|
|
||||||
|
### 3. Chat
|
||||||
|
|
||||||
|
Send a message as usual. When the model decides it needs a tool, it will call it automatically. You will see each tool call and its result in a collapsible accordion inside the chat message.
|
||||||
|
|
||||||
|
The model may call multiple tools in sequence before giving its final answer.
|
||||||
|
|
||||||
|
## Writing custom tools
|
||||||
|
|
||||||
|
Each tool is a single `.py` file in `user_data/tools/`. It needs two things:
|
||||||
|
|
||||||
|
1. A `tool` dictionary that describes the function (name, description, parameters).
|
||||||
|
2. An `execute(arguments)` function that runs it and returns the result.
|
||||||
|
|
||||||
|
Here is a minimal example (`user_data/tools/get_datetime.py`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
tool = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_datetime",
|
||||||
|
"description": "Get the current date and time.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def execute(arguments):
|
||||||
|
now = datetime.now()
|
||||||
|
return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")}
|
||||||
|
```
|
||||||
|
|
||||||
|
An example with parameters (`user_data/tools/roll_dice.py`):
|
||||||
|
|
||||||
|
```python
|
||||||
|
import random
|
||||||
|
|
||||||
|
tool = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "roll_dice",
|
||||||
|
"description": "Roll one or more dice with the specified number of sides.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"count": {"type": "integer", "description": "Number of dice to roll.", "default": 1},
|
||||||
|
"sides": {"type": "integer", "description": "Number of sides per die.", "default": 20},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def execute(arguments):
|
||||||
|
count = max(1, min(arguments.get("count", 1), 1000))
|
||||||
|
sides = max(2, min(arguments.get("sides", 20), 1000))
|
||||||
|
rolls = [random.randint(1, sides) for _ in range(count)]
|
||||||
|
return {"rolls": rolls, "total": sum(rolls)}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can open the built-in tools in `user_data/tools/` for more examples.
|
||||||
|
|
||||||
|
## Tool calling over the API
|
||||||
|
|
||||||
|
Tool calling over the API follows the [OpenAI API](https://platform.openai.com/docs/guides/function-calling) convention. Define your tools, send them with your messages, and handle tool calls in a loop until the model gives a final answer.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
|
url = "http://127.0.0.1:5000/v1/chat/completions"
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather for a given location.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {"type": "string", "description": "City name"},
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def execute_tool(name, arguments):
|
||||||
|
if name == "get_weather":
|
||||||
|
return {"temperature": "14°C", "condition": "partly cloudy"}
|
||||||
|
return {"error": f"Unknown tool: {name}"}
|
||||||
|
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "What's the weather like in Paris?"}]
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
response = requests.post(url, json={"messages": messages, "tools": tools}).json()
|
||||||
|
choice = response["choices"][0]
|
||||||
|
|
||||||
|
if choice["finish_reason"] == "tool_calls":
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": choice["message"]["content"],
|
||||||
|
"tool_calls": choice["message"]["tool_calls"],
|
||||||
|
})
|
||||||
|
|
||||||
|
for tool_call in choice["message"]["tool_calls"]:
|
||||||
|
name = tool_call["function"]["name"]
|
||||||
|
arguments = json.loads(tool_call["function"]["arguments"])
|
||||||
|
result = execute_tool(name, arguments)
|
||||||
|
print(f"Tool call: {name}({arguments}) => {result}")
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call["id"],
|
||||||
|
"content": json.dumps(result),
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
print(f"\nAssistant: {choice['message']['content']}")
|
||||||
|
break
|
||||||
|
```
|
||||||
|
|
@ -1,20 +1,17 @@
|
||||||
## What Works
|
## What Works
|
||||||
|
|
||||||
| Loader | Loading 1 LoRA | Loading 2 or more LoRAs | Training LoRAs | Multimodal extension | Perplexity evaluation |
|
| Loader | Loading LoRAs | Training LoRAs | Multimodal | Perplexity evaluation |
|
||||||
|----------------|----------------|-------------------------|----------------|----------------------|-----------------------|
|
|----------------|---------------|----------------|------------|-----------------------|
|
||||||
| Transformers | ✅ | ✅\*\* | ✅\* | ✅ | ✅ |
|
| llama.cpp | ❌ | ❌ | ✅\* | ❌ |
|
||||||
| llama.cpp | ❌ | ❌ | ❌ | ❌ | use llamacpp_HF |
|
| Transformers | ✅ | ✅ | ✅\*\* | ✅ |
|
||||||
| llamacpp_HF | ❌ | ❌ | ❌ | ❌ | ✅ |
|
| ExLlamav3_HF | ❌ | ❌ | ❌ | ✅ |
|
||||||
| ExLlamav2_HF | ✅ | ✅ | ❌ | ❌ | ✅ |
|
| ExLlamav3 | ❌ | ❌ | ✅ | ❌ |
|
||||||
| ExLlamav2 | ✅ | ✅ | ❌ | ❌ | use ExLlamav2_HF |
|
| TensorRT-LLM | ❌ | ❌ | ❌ | ❌ |
|
||||||
| AutoGPTQ | ✅ | ❌ | ❌ | ✅ | ✅ |
|
|
||||||
| AutoAWQ | ? | ❌ | ? | ? | ✅ |
|
|
||||||
| HQQ | ? | ? | ? | ? | ✅ |
|
|
||||||
|
|
||||||
❌ = not implemented
|
❌ = not supported
|
||||||
|
|
||||||
✅ = implemented
|
✅ = supported
|
||||||
|
|
||||||
\* Training LoRAs with GPTQ models also works with the Transformers loader. Make sure to check "auto-devices" and "disable_exllama" before loading the model.
|
\* Via the `mmproj` parameter (multimodal projector file).
|
||||||
|
|
||||||
\*\* Multi-LoRA in PEFT is tricky and the current implementation does not work reliably in all cases.
|
\*\* Via the `send_pictures` extension.
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,8 @@ from requests.adapters import HTTPAdapter
|
||||||
from requests.exceptions import ConnectionError, RequestException, Timeout
|
from requests.exceptions import ConnectionError, RequestException, Timeout
|
||||||
from tqdm.contrib.concurrent import thread_map
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
|
from modules.paths import resolve_user_data_dir
|
||||||
|
|
||||||
base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
|
base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -182,11 +184,13 @@ class ModelDownloader:
|
||||||
is_llamacpp = has_gguf and specific_file is not None
|
is_llamacpp = has_gguf and specific_file is not None
|
||||||
return links, sha256, is_lora, is_llamacpp, file_sizes
|
return links, sha256, is_lora, is_llamacpp, file_sizes
|
||||||
|
|
||||||
def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None):
|
def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None, user_data_dir=None):
|
||||||
if model_dir:
|
if model_dir:
|
||||||
base_folder = model_dir
|
base_folder = model_dir
|
||||||
else:
|
else:
|
||||||
base_folder = 'user_data/models' if not is_lora else 'user_data/loras'
|
if user_data_dir is None:
|
||||||
|
user_data_dir = resolve_user_data_dir()
|
||||||
|
base_folder = str(user_data_dir / 'models') if not is_lora else str(user_data_dir / 'loras')
|
||||||
|
|
||||||
# If the model is of type GGUF, save directly in the base_folder
|
# If the model is of type GGUF, save directly in the base_folder
|
||||||
if is_llamacpp:
|
if is_llamacpp:
|
||||||
|
|
@ -392,7 +396,8 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
|
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
|
||||||
parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')
|
parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')
|
||||||
parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
|
parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
|
||||||
parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/user_data/models).')
|
parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (user_data/models).')
|
||||||
|
parser.add_argument('--user-data-dir', type=str, default=None, help='Path to the user data directory. Overrides auto-detection.')
|
||||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||||
parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.')
|
parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.')
|
||||||
|
|
@ -408,6 +413,26 @@ if __name__ == '__main__':
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
downloader = ModelDownloader(max_retries=args.max_retries)
|
downloader = ModelDownloader(max_retries=args.max_retries)
|
||||||
|
|
||||||
|
# Handle direct file URLs (e.g. https://huggingface.co/org/repo/resolve/branch/file.gguf)
|
||||||
|
if '/resolve/' in model:
|
||||||
|
url = model if model.startswith('http') else f'{base}/{model}'
|
||||||
|
url = url.split('?')[0]
|
||||||
|
filename = url.split('/')[-1]
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
output_folder = Path(args.output)
|
||||||
|
elif args.model_dir:
|
||||||
|
output_folder = Path(args.model_dir)
|
||||||
|
else:
|
||||||
|
user_data_dir = Path(args.user_data_dir) if args.user_data_dir else resolve_user_data_dir()
|
||||||
|
output_folder = user_data_dir / 'models'
|
||||||
|
|
||||||
|
output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"Downloading {filename} to {output_folder}")
|
||||||
|
downloader.get_single_file(url, output_folder, start_from_scratch=args.clean)
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
# Clean up the model/branch names
|
# Clean up the model/branch names
|
||||||
try:
|
try:
|
||||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||||
|
|
@ -421,10 +446,11 @@ if __name__ == '__main__':
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the output folder
|
# Get the output folder
|
||||||
|
user_data_dir = Path(args.user_data_dir) if args.user_data_dir else None
|
||||||
if args.output:
|
if args.output:
|
||||||
output_folder = Path(args.output)
|
output_folder = Path(args.output)
|
||||||
else:
|
else:
|
||||||
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir)
|
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir, user_data_dir=user_data_dir)
|
||||||
|
|
||||||
if args.check:
|
if args.check:
|
||||||
# Check previously downloaded files
|
# Check previously downloaded files
|
||||||
|
|
|
||||||
|
|
@ -1,92 +0,0 @@
|
||||||
# Training_PRO
|
|
||||||
|
|
||||||
This is an expanded and reworked Training tab
|
|
||||||
Maintained by FP
|
|
||||||
|
|
||||||
[](https://ko-fi.com/Q5Q5MOB4M)
|
|
||||||
|
|
||||||
Repo home:
|
|
||||||
|
|
||||||
https://github.com/FartyPants/Training_PRO
|
|
||||||
|
|
||||||
In general the repo above is ahead of the extension included in text WebUi.
|
|
||||||
|
|
||||||
## News
|
|
||||||
|
|
||||||
- NEFtune: add noise to help with generalization
|
|
||||||
- Loss Graph in interface.
|
|
||||||
- Supports Mistral training
|
|
||||||
- some roundabout around pytorch and transformers version desync
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
## Features/Changes
|
|
||||||
|
|
||||||
- Chunking: precise raw text slicer (PRTS) uses sentence slicing and making sure things are clean on all ends
|
|
||||||
- overlap chunking - this special overlapping will make additional overlap block based on logical rules (aka no overlap block on hard cut)
|
|
||||||
- custom scheduler (follow the code to make your own) In LR Scheduler select FP_low_epoch_annealing - this scheduler will keep the LR constant for first epoch then use cosine for the rest - this part would be best to spawn into a new py file
|
|
||||||
- saves graph png file at the end with learning rate and loss per epoch
|
|
||||||
- adding EOS to each block or to hard cut only
|
|
||||||
- automatically lowers gradient accumulation if you go overboard and set gradient accumulation that will be higher than actual data - transformers would then throw error (or they used to, not sure if still true) but in any way, it will fix bad data
|
|
||||||
- turn BOS on and OFF
|
|
||||||
- target selector
|
|
||||||
- DEMENTOR LEARNING (experimental) Deep Memorization Enforcement Through Overlapping and Repetition. This is an experiment for long-text learning using low epochs (basically use 1 epoch with constant LR or 2 epochs with FP_low_epoch_annealing LR scheduler)
|
|
||||||
- Getting rid of micro batch size/batch size confusion. Now there is True Batch Size and Gradient accumulation slider, consisten with all the other training out there
|
|
||||||
- Ability to save Checkpoint during training with a button
|
|
||||||
- Ability to change Stop Loss during training
|
|
||||||
- different modes of checkpoint auto saving
|
|
||||||
- Function to Check Dataset and suggest parameters such as warmup and checkpoint save frequency before training
|
|
||||||
- Graph Training Loss in interface
|
|
||||||
- more custom schedulers
|
|
||||||
|
|
||||||
### Notes:
|
|
||||||
|
|
||||||
This uses it's own chunking code for raw text based on sentence splitting. This will avoid weird cuts in the chunks and each chunk should now start with sentence and end on some sentence. It works hand in hand with Hard Cut. A propper use is to structure your text into logical blocks (ideas) separated by three \n then use three \n in hard cut. This way each chunk will contain only one flow of ideas and not derail in the thoughts. And Overlapping code will create overlapped blocks on sentence basis too, but not cross hard cut, thus not cross different ideas either. Does it make any sense? No? Hmmmm...
|
|
||||||
|
|
||||||
### Custom schedulers
|
|
||||||
|
|
||||||
A bunch of custom (combination) schedulers are added to the LR schedule. These are based on my own experiments
|
|
||||||
|
|
||||||
**FP_low_epoch_annealing**
|
|
||||||
|
|
||||||
Uses constant LR (with warmup) for 1 epoch only. The rest of the epoch(s) is cosine annealing. So 10 epochs - 1 will be constant 9 will be nose dive down. However a typical usage would be 2 epochs (hence low epoch in name). 1st is constant, the second is annealing. Simple. I use it 90% of time.
|
|
||||||
|
|
||||||
**FP_half_time_annealing**
|
|
||||||
|
|
||||||
Like the low epoch, but now the total number of steps is divided by 2. First half is constant, second half is annealing. So 10 epochs - 5 will be constant, 5 will be cosine nose down.
|
|
||||||
|
|
||||||
**FP_raise_fall_creative**
|
|
||||||
|
|
||||||
This is a sine raise till half of the total steps then cosine fall the rest. (Or you may think of the curve as sine in its entirety. The most learning is done in the hump, in the middle. The warmup entry has no effect, since sine is automatically warm up.
|
|
||||||
The idea is to start very mildly as not to overfit with the first blocks of dataset. It seems to broaden the scope of the model making it less strict for tight dataset.
|
|
||||||
|
|
||||||
### Targets
|
|
||||||
|
|
||||||
Normal LORA is q, v and that's what you should use. You can use (q k v o) or (q k v) and it will give you a lot more trainable parameters. The benefit is that you can keep rank lower and still attain the same coherency as q v with high rank. Guanaco has been trained with QLORA and q k v o for example and they swear by it.
|
|
||||||
|
|
||||||
### DEMENTOR LEARNING (experimental) Deep Memorization Enforcement Through Overlapping and Repetition
|
|
||||||
|
|
||||||
This is and experimental chunking to train long-form text in low number of epochs (basically 1) with sliding repetition. The depth of learning directly depends on the cutoff_length. Increasing cutoff length will also increase number of blocks created from long-form text (which is contrary to normal training). It is based on my own wild experiments.
|
|
||||||
|
|
||||||
### Getting rid of batch size and micro batch size
|
|
||||||
|
|
||||||
Keeping consistency with everyone else.
|
|
||||||
|
|
||||||
Listen, There is only ONE batch size - the True batch size (called previously micro-batch size in WebUI) - this is how many blocks are processed at once (during a single step). It eats GPU, but it really helps with the quality training (in fact the ideal batch size would be the same as number of blocks - which is unrealistic) - so the idea is to cram as much True Batch Size before your GPU blows with OOM. On 24GB this is about 10 for 13b (loaded with 4-bit)
|
|
||||||
|
|
||||||
So no micro batch size - it is now called True Batch Size, because that's what it is.
|
|
||||||
|
|
||||||
The other thing is Gradient Accumulation - this is an emulation of the above Batch size - a virtual batch size, if you will. If your GPU can't handle real batch size then you may fake it using Gradient Accumulation. This will accumulate the gradients over so many steps defined here and then update the weights at the end without increase in GPU.
|
|
||||||
Gradient accumulation is like a virtual Batch size multiplier without the GPU penalty.
|
|
||||||
|
|
||||||
If your batch size is 4 and your gradient accumulation is 2 then it sort of behaves as if we have batch size 8. *Sort of* because Batch size of 4 and GA of 2 is NOT the same as batch size of 2 and GA of 4. (It produces different weights - hence it's not an equivalent). The idea is that if you don't have GPU - using GA to extend batch size is the next best thing (good enough) since you have no other choice.
|
|
||||||
|
|
||||||
If all you can afford is 1 batch size, then increasing GA will likely make the learning better in some range of GA (it's not always more is better).
|
|
||||||
|
|
||||||
However - GA is not some golden goose. As said, it isn't the same as batch size. In fact GA may worsen your learning as well.
|
|
||||||
|
|
||||||
I would suggest a series of experiment where you would put batch size as high as possible without OOM, set GA 1, then repeat training while increasing the GA (2, 4...), and see how the model changes. It's likely that it would follow some sort of curve where GA will seem to help before it will make it worse. Some people believe that if you can squeeze 6 BATCH Size, then you should not bother with GA at all... YMMW
|
|
||||||
|
|
||||||
High Batch Size vs High GA would also likely produce different results in terms of learning words vs style. How? Hmmmm... good question.
|
|
||||||
|
|
||||||
One optical "benefit" of GA is that the loss will fluctuate less (because of all the gradient accumulation, which works as a form of noise smoothing as well).
|
|
||||||
|
|
@ -1,433 +0,0 @@
|
||||||
from functools import partial
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
import math
|
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
|
||||||
|
|
||||||
from peft import (
|
|
||||||
PeftModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
RED = "\033[91m"
|
|
||||||
YELLOW = "\033[93m"
|
|
||||||
GREEN = "\033[92m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
last_print_label = ''
|
|
||||||
|
|
||||||
custom_scheduler_params = {'trigger_loss': 0.0, 'ramp_down_ratio':1.0, 'current_loss': 0.0,'dynamic_scheduler_stop': False, 'calc_ramp_down_at_step': 0, 'calc_num_training_steps': 0}
|
|
||||||
|
|
||||||
|
|
||||||
def custom_scheduler_global_update(current_loss: float):
|
|
||||||
custom_scheduler_params.update({'current_loss': current_loss})
|
|
||||||
|
|
||||||
def custom_scheduler_global_setup(trigger_loss: float, ramp_down_ratio: float):
|
|
||||||
custom_scheduler_params.update({'trigger_loss': trigger_loss})
|
|
||||||
custom_scheduler_params.update({'ramp_down_ratio': ramp_down_ratio})
|
|
||||||
|
|
||||||
# calculates the total num steps after trigger
|
|
||||||
custom_scheduler_params.update({'calc_num_training_steps': 0})
|
|
||||||
#calculates steps when the ramp_down trigger occured
|
|
||||||
custom_scheduler_params.update({'calc_ramp_down_at_step': 0})
|
|
||||||
# triggers scheduler stopping after it reached calc_num_training_steps
|
|
||||||
custom_scheduler_params.update({'dynamic_scheduler_stop': False})
|
|
||||||
|
|
||||||
|
|
||||||
# hold constant to the half of epochs then cosine down to 0
|
|
||||||
def _get_fp_half_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
|
|
||||||
|
|
||||||
global last_print_label
|
|
||||||
print_label = ''
|
|
||||||
|
|
||||||
half_steps = num_training_steps//2
|
|
||||||
|
|
||||||
num_warmup_steps = min(num_warmup_steps,half_steps)
|
|
||||||
|
|
||||||
if current_step < num_warmup_steps:
|
|
||||||
print_label = 'Scheduler: Warmup'
|
|
||||||
elif current_step < half_steps:
|
|
||||||
print_label = 'Scheduler: Hold'
|
|
||||||
else:
|
|
||||||
print_label = 'Scheduler: Annealing'
|
|
||||||
|
|
||||||
if print_label != last_print_label:
|
|
||||||
print(print_label)
|
|
||||||
|
|
||||||
last_print_label = print_label
|
|
||||||
|
|
||||||
if current_step < num_warmup_steps:
|
|
||||||
return float(current_step) / float(max(1, num_warmup_steps))
|
|
||||||
|
|
||||||
if current_step < half_steps:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
progress = float(current_step - half_steps) / float(max(1, num_training_steps - half_steps))
|
|
||||||
num_cycles = 0.5
|
|
||||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
|
||||||
|
|
||||||
|
|
||||||
# raise up in cosine, then fall back in cosine
|
|
||||||
def _get_fp_cosine_raise_and_fall_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
|
|
||||||
|
|
||||||
global last_print_label
|
|
||||||
print_label = ''
|
|
||||||
|
|
||||||
half_steps = num_training_steps//2
|
|
||||||
|
|
||||||
#num_warmup_steps = min(num_warmup_steps,half_steps)
|
|
||||||
|
|
||||||
if current_step < half_steps:
|
|
||||||
print_label = 'Scheduler: Raise'
|
|
||||||
else:
|
|
||||||
print_label = 'Scheduler: Fall'
|
|
||||||
|
|
||||||
if print_label != last_print_label:
|
|
||||||
print(print_label)
|
|
||||||
|
|
||||||
last_print_label = print_label
|
|
||||||
|
|
||||||
|
|
||||||
# linear
|
|
||||||
# return float(current_step) / float(max(1, num_warmup_steps))
|
|
||||||
|
|
||||||
progress = float(current_step - half_steps) / float(max(1, num_training_steps - half_steps))
|
|
||||||
num_cycles = 0.5
|
|
||||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
|
||||||
|
|
||||||
# constant to the first epochs then cosine down to 0 over the rest epochs
|
|
||||||
def _get_fp_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
|
|
||||||
|
|
||||||
global last_print_label
|
|
||||||
print_label = ''
|
|
||||||
|
|
||||||
num_warmup_steps = min(num_warmup_steps,num_firstepoch_steps)
|
|
||||||
|
|
||||||
if current_step < num_warmup_steps:
|
|
||||||
print_label = 'Scheduler: Warmup'
|
|
||||||
elif current_step < num_firstepoch_steps:
|
|
||||||
print_label = 'Scheduler: Hold'
|
|
||||||
else:
|
|
||||||
print_label = 'Scheduler: Annealing'
|
|
||||||
|
|
||||||
if print_label != last_print_label:
|
|
||||||
print(print_label)
|
|
||||||
|
|
||||||
last_print_label = print_label
|
|
||||||
|
|
||||||
if current_step < num_warmup_steps:
|
|
||||||
return float(current_step) / float(max(1, num_warmup_steps))
|
|
||||||
|
|
||||||
if current_step < num_firstepoch_steps:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
progress = float(current_step - num_firstepoch_steps) / float(max(1, num_training_steps - num_firstepoch_steps))
|
|
||||||
num_cycles = 0.5
|
|
||||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
|
||||||
|
|
||||||
# halve lr each epoch
|
|
||||||
|
|
||||||
def _get_fp_cdrop_rate_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_firstepoch_steps: int):
|
|
||||||
|
|
||||||
global last_print_label
|
|
||||||
print_label = ''
|
|
||||||
|
|
||||||
num_warmup_steps = min(num_warmup_steps, num_firstepoch_steps)
|
|
||||||
|
|
||||||
current_epoch = (current_step // num_firstepoch_steps) + 1
|
|
||||||
|
|
||||||
|
|
||||||
if current_step < num_warmup_steps:
|
|
||||||
print_label = 'Scheduler: Warmup'
|
|
||||||
elif current_step < num_firstepoch_steps:
|
|
||||||
print_label = 'Scheduler: Hold'
|
|
||||||
else:
|
|
||||||
print_label = 'Scheduler: Drop Rate'
|
|
||||||
|
|
||||||
if print_label != last_print_label:
|
|
||||||
print(print_label)
|
|
||||||
|
|
||||||
last_print_label = print_label
|
|
||||||
|
|
||||||
if current_step < num_warmup_steps:
|
|
||||||
return float(current_step) / float(max(1, num_warmup_steps))
|
|
||||||
|
|
||||||
if current_step < num_firstepoch_steps:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
# Compute the learning rate for the annealing phase
|
|
||||||
|
|
||||||
learning_rate = 1.0 / float(2 ** (current_epoch - 1))
|
|
||||||
|
|
||||||
return learning_rate
|
|
||||||
|
|
||||||
# epoch decay: 1/(1 + decay * epoch)
|
|
||||||
|
|
||||||
def custom_cosine_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
optimizer ([`~torch.optim.Optimizer`]):
|
|
||||||
The optimizer for which to schedule the learning rate.
|
|
||||||
num_warmup_steps (`int`):
|
|
||||||
The number of steps for the warmup phase.
|
|
||||||
num_training_steps (`int`):
|
|
||||||
The total number of training steps.
|
|
||||||
last_epoch (`int`, *optional*, defaults to -1):
|
|
||||||
The index of the last epoch when resuming training.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
|
||||||
"""
|
|
||||||
|
|
||||||
lr_lambda = partial(
|
|
||||||
_get_fp_cosine_schedule_with_warmup_lr_lambda,
|
|
||||||
num_warmup_steps=num_warmup_steps,
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
num_firstepoch_steps = num_firstepoch_steps,
|
|
||||||
)
|
|
||||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
||||||
|
|
||||||
def custom_half_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
optimizer ([`~torch.optim.Optimizer`]):
|
|
||||||
The optimizer for which to schedule the learning rate.
|
|
||||||
num_warmup_steps (`int`):
|
|
||||||
The number of steps for the warmup phase.
|
|
||||||
num_training_steps (`int`):
|
|
||||||
The total number of training steps.
|
|
||||||
last_epoch (`int`, *optional*, defaults to -1):
|
|
||||||
The index of the last epoch when resuming training.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
|
||||||
"""
|
|
||||||
|
|
||||||
lr_lambda = partial(
|
|
||||||
_get_fp_half_schedule_with_warmup_lr_lambda,
|
|
||||||
num_warmup_steps=num_warmup_steps,
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
num_firstepoch_steps = num_firstepoch_steps,
|
|
||||||
)
|
|
||||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
||||||
|
|
||||||
def custom_raise_fall_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_firstepoch_steps, last_epoch=-1):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
optimizer ([`~torch.optim.Optimizer`]):
|
|
||||||
The optimizer for which to schedule the learning rate.
|
|
||||||
num_warmup_steps (`int`):
|
|
||||||
The number of steps for the warmup phase.
|
|
||||||
num_training_steps (`int`):
|
|
||||||
The total number of training steps.
|
|
||||||
last_epoch (`int`, *optional*, defaults to -1):
|
|
||||||
The index of the last epoch when resuming training.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
|
||||||
"""
|
|
||||||
|
|
||||||
lr_lambda = partial(
|
|
||||||
_get_fp_cosine_raise_and_fall_lr_lambda,
|
|
||||||
num_warmup_steps=num_warmup_steps,
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
num_firstepoch_steps = num_firstepoch_steps,
|
|
||||||
)
|
|
||||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
||||||
|
|
||||||
|
|
||||||
def neftune_forward(self, input: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Implements the NEFTune forward pass for the model. Note this works only for
|
|
||||||
torch.nn.Embedding layers. This method is slightly adapted from the original source code
|
|
||||||
that can be found here: https://github.com/neelsjain/NEFTune
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input (`torch.Tensor`):
|
|
||||||
The input tensor to the model.
|
|
||||||
noise_alpha (`float`):
|
|
||||||
The noise alpha value to use for the NEFTune forward pass.
|
|
||||||
"""
|
|
||||||
embeddings = torch.nn.functional.embedding(
|
|
||||||
input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.training:
|
|
||||||
# Add noise to the embeddings
|
|
||||||
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
|
||||||
mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)
|
|
||||||
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
class FPNEFtuneTrainer(transformers.Trainer):
|
|
||||||
def __init__(self,neftune_noise_alpha:float = 0.0, model = None, *args, **kwargs):
|
|
||||||
self.neftune_noise_alpha = neftune_noise_alpha
|
|
||||||
if self.neftune_noise_alpha > 0.0:
|
|
||||||
model = self._activate_neftune(model)
|
|
||||||
super().__init__(model = model, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _activate_neftune(self, model):
|
|
||||||
r"""
|
|
||||||
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
print(f"Activating {RED}NEFtune{RESET} with scale: {self.neftune_noise_alpha}")
|
|
||||||
if isinstance(model, transformers.PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
elif isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
|
|
||||||
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
|
|
||||||
old_forward = embeddings.forward
|
|
||||||
|
|
||||||
# This hack seems to be needed to properly use a custom forward pass
|
|
||||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
|
||||||
bound_method = neftune_forward.__get__(embeddings, embeddings.__class__)
|
|
||||||
setattr(embeddings, "forward", bound_method)
|
|
||||||
|
|
||||||
# embeddings.forward = neftune_forward
|
|
||||||
embeddings._trl_old_forward = old_forward
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def train(self, *args, **kwargs):
|
|
||||||
output = super().train(*args, **kwargs)
|
|
||||||
|
|
||||||
# After training we make sure to retrieve back the original forward pass method
|
|
||||||
# for the embedding layer
|
|
||||||
if self.neftune_noise_alpha is not None:
|
|
||||||
|
|
||||||
if isinstance(self.model, transformers.PreTrainedModel):
|
|
||||||
embeddings = self.model.get_input_embeddings()
|
|
||||||
elif isinstance(self.model, PeftModel):
|
|
||||||
embeddings = self.model.base_model.get_input_embeddings()
|
|
||||||
|
|
||||||
if hasattr(embeddings, "_trl_old_forward"):
|
|
||||||
embeddings.forward = embeddings._trl_old_forward
|
|
||||||
del embeddings._trl_old_forward
|
|
||||||
del embeddings.neftune_noise_alpha
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class FPSchedulerTrainer(transformers.Trainer):
|
|
||||||
def __init__(self,neftune_noise_alpha:float = 0.0, model = None, *args, **kwargs):
|
|
||||||
self.neftune_noise_alpha = neftune_noise_alpha
|
|
||||||
if self.neftune_noise_alpha > 0.0:
|
|
||||||
model = self._activate_neftune(model)
|
|
||||||
super().__init__(model = model, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _activate_neftune(self, model):
|
|
||||||
r"""
|
|
||||||
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
print(f"Activating {RED}NEFtune{RESET} with scale: {self.neftune_noise_alpha}")
|
|
||||||
if isinstance(model, transformers.PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
elif isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
|
|
||||||
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
|
|
||||||
old_forward = embeddings.forward
|
|
||||||
|
|
||||||
# This hack seems to be needed to properly use a custom forward pass
|
|
||||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
|
||||||
bound_method = neftune_forward.__get__(embeddings, embeddings.__class__)
|
|
||||||
setattr(embeddings, "forward", bound_method)
|
|
||||||
|
|
||||||
# embeddings.forward = neftune_forward
|
|
||||||
embeddings._trl_old_forward = old_forward
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
def train(self, *args, **kwargs):
|
|
||||||
output = super().train(*args, **kwargs)
|
|
||||||
|
|
||||||
# After training we make sure to retrieve back the original forward pass method
|
|
||||||
# for the embedding layer
|
|
||||||
if self.neftune_noise_alpha is not None:
|
|
||||||
|
|
||||||
if isinstance(self.model, transformers.PreTrainedModel):
|
|
||||||
embeddings = self.model.get_input_embeddings()
|
|
||||||
elif isinstance(self.model, PeftModel):
|
|
||||||
embeddings = self.model.base_model.get_input_embeddings()
|
|
||||||
|
|
||||||
if hasattr(embeddings, "_trl_old_forward"):
|
|
||||||
embeddings.forward = embeddings._trl_old_forward
|
|
||||||
del embeddings._trl_old_forward
|
|
||||||
del embeddings.neftune_noise_alpha
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
|
|
||||||
#Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument.
|
|
||||||
|
|
||||||
num_train_epochs = self.args.num_train_epochs
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps)
|
|
||||||
num_firstepoch_steps = math.ceil(num_training_steps/num_train_epochs)
|
|
||||||
num_warmup_acc = num_warmup_steps*self.args.gradient_accumulation_steps
|
|
||||||
num_firstepoch_steps_acc = num_firstepoch_steps*self.args.gradient_accumulation_steps
|
|
||||||
num_training_steps_acc = num_training_steps*self.args.gradient_accumulation_steps
|
|
||||||
|
|
||||||
custom_scheduler_params.update({'dynamic_scheduler_stop': False})
|
|
||||||
|
|
||||||
print (f"Warm-up steps aligned to Gradient accumulation ({self.args.gradient_accumulation_steps}) = {num_warmup_acc} actual warmup steps")
|
|
||||||
if self.args.lr_scheduler_type == 'cosine':
|
|
||||||
|
|
||||||
num_warmup_acc_min = min(num_warmup_acc, num_firstepoch_steps_acc)
|
|
||||||
|
|
||||||
if num_warmup_acc>num_firstepoch_steps_acc:
|
|
||||||
print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to 1 epoch, essentially going from warmup to annealing.\033[0;37;0m")
|
|
||||||
print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")
|
|
||||||
else:
|
|
||||||
print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{num_firstepoch_steps_acc}, Annealing {num_firstepoch_steps_acc}-{num_training_steps_acc}")
|
|
||||||
|
|
||||||
self.lr_scheduler = custom_cosine_scheduler_with_warmup(
|
|
||||||
optimizer=self.optimizer if optimizer is None else optimizer,
|
|
||||||
num_warmup_steps=num_warmup_steps,
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
num_firstepoch_steps = num_firstepoch_steps,
|
|
||||||
)
|
|
||||||
self._created_lr_scheduler = True
|
|
||||||
return self.lr_scheduler
|
|
||||||
elif self.args.lr_scheduler_type == 'constant':
|
|
||||||
|
|
||||||
half_step_acc = num_training_steps_acc//2
|
|
||||||
num_warmup_acc_min = min(num_warmup_acc, half_step_acc)
|
|
||||||
|
|
||||||
if num_warmup_acc>half_step_acc:
|
|
||||||
print(f"\033[1;31;1mWARNING: The number of warmup steps is set too high! It will be clamped to half of all epochs, essentially going from warmup to annealing in the middle.\033[0;37;0m")
|
|
||||||
print (f"FP Scheduler Warmup: 0-[{num_warmup_acc_min}], Hold [{num_warmup_acc_min}]-{half_step_acc}, Annealing {half_step_acc}-{num_training_steps_acc}")
|
|
||||||
else:
|
|
||||||
print (f"FP Scheduler Warmup: 0-{num_warmup_acc_min}, Hold {num_warmup_acc_min}-{half_step_acc}, Annealing {half_step_acc}-{num_training_steps_acc}")
|
|
||||||
|
|
||||||
self.lr_scheduler = custom_half_scheduler_with_warmup(
|
|
||||||
optimizer=self.optimizer if optimizer is None else optimizer,
|
|
||||||
num_warmup_steps=num_warmup_steps,
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
num_firstepoch_steps = num_firstepoch_steps,
|
|
||||||
)
|
|
||||||
self._created_lr_scheduler = True
|
|
||||||
return self.lr_scheduler
|
|
||||||
elif self.args.lr_scheduler_type == 'constant_with_warmup':
|
|
||||||
|
|
||||||
half_step_acc = num_training_steps_acc//2
|
|
||||||
|
|
||||||
if num_warmup_steps>0:
|
|
||||||
print(f"Warmup doesn't apply to this scheduler [Raise-Fall]")
|
|
||||||
|
|
||||||
print (f"Scheduler Raise: 0-{half_step_acc}, Fall {half_step_acc}-{num_training_steps_acc}")
|
|
||||||
|
|
||||||
self.lr_scheduler = custom_raise_fall_scheduler_with_warmup(
|
|
||||||
optimizer=self.optimizer if optimizer is None else optimizer,
|
|
||||||
num_warmup_steps=num_warmup_steps,
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
num_firstepoch_steps = num_firstepoch_steps,
|
|
||||||
)
|
|
||||||
self._created_lr_scheduler = True
|
|
||||||
return self.lr_scheduler
|
|
||||||
else:
|
|
||||||
return super().create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
import os
|
|
||||||
import json
|
|
||||||
|
|
||||||
def create_graph(lora_path, lora_name):
|
|
||||||
try:
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from matplotlib.ticker import ScalarFormatter
|
|
||||||
|
|
||||||
peft_model_path = f'{lora_path}/training_graph.json'
|
|
||||||
image_model_path = f'{lora_path}/training_graph.png'
|
|
||||||
# Check if the JSON file exists
|
|
||||||
if os.path.exists(peft_model_path):
|
|
||||||
# Load data from JSON file
|
|
||||||
with open(peft_model_path, 'r') as file:
|
|
||||||
data = json.load(file)
|
|
||||||
# Extract x, y1, and y2 values
|
|
||||||
x = [item['epoch'] for item in data]
|
|
||||||
y1 = [item['learning_rate'] for item in data]
|
|
||||||
y2 = [item['loss'] for item in data]
|
|
||||||
|
|
||||||
# Create the line chart
|
|
||||||
fig, ax1 = plt.subplots(figsize=(10, 6))
|
|
||||||
|
|
||||||
|
|
||||||
# Plot y1 (learning rate) on the first y-axis
|
|
||||||
ax1.plot(x, y1, 'b-', label='Learning Rate')
|
|
||||||
ax1.set_xlabel('Epoch')
|
|
||||||
ax1.set_ylabel('Learning Rate', color='b')
|
|
||||||
ax1.tick_params('y', colors='b')
|
|
||||||
|
|
||||||
# Create a second y-axis
|
|
||||||
ax2 = ax1.twinx()
|
|
||||||
|
|
||||||
# Plot y2 (loss) on the second y-axis
|
|
||||||
ax2.plot(x, y2, 'r-', label='Loss')
|
|
||||||
ax2.set_ylabel('Loss', color='r')
|
|
||||||
ax2.tick_params('y', colors='r')
|
|
||||||
|
|
||||||
# Set the y-axis formatter to display numbers in scientific notation
|
|
||||||
ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
|
|
||||||
ax1.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
|
|
||||||
|
|
||||||
# Add grid
|
|
||||||
ax1.grid(True)
|
|
||||||
|
|
||||||
# Combine the legends for both plots
|
|
||||||
lines, labels = ax1.get_legend_handles_labels()
|
|
||||||
lines2, labels2 = ax2.get_legend_handles_labels()
|
|
||||||
ax2.legend(lines + lines2, labels + labels2, loc='best')
|
|
||||||
|
|
||||||
# Set the title
|
|
||||||
plt.title(f'{lora_name} LR and Loss vs Epoch')
|
|
||||||
|
|
||||||
# Save the chart as an image
|
|
||||||
plt.savefig(image_model_path)
|
|
||||||
|
|
||||||
print(f"Graph saved in {image_model_path}")
|
|
||||||
else:
|
|
||||||
print(f"File 'training_graph.json' does not exist in the {lora_path}")
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
print("matplotlib is not installed. Please install matplotlib to create PNG graphs")
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,368 +0,0 @@
|
||||||
import os
|
|
||||||
from modules import shared, utils
|
|
||||||
from pathlib import Path
|
|
||||||
import requests
|
|
||||||
import tqdm
|
|
||||||
import json
|
|
||||||
|
|
||||||
'''
|
|
||||||
def get_gpu_memory_usage(rank):
|
|
||||||
return {
|
|
||||||
'total': round(torch.cuda.get_device_properties(rank).total_memory / (1024**3), 2),
|
|
||||||
'max': round(torch.cuda.max_memory_allocated(rank) / (1024**3), 2),
|
|
||||||
'reserved': round(torch.cuda.memory_reserved(rank) / (1024**3), 2),
|
|
||||||
'allocated': round(torch.cuda.memory_allocated(rank) / (1024**3), 2)
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
|
|
||||||
def list_subfoldersByTime(directory):
|
|
||||||
|
|
||||||
if not directory.endswith('/'):
|
|
||||||
directory += '/'
|
|
||||||
subfolders = []
|
|
||||||
subfolders.append('None')
|
|
||||||
path = directory
|
|
||||||
name_list = os.listdir(path)
|
|
||||||
full_list = [os.path.join(path,i) for i in name_list]
|
|
||||||
time_sorted_list = sorted(full_list, key=os.path.getmtime,reverse=True)
|
|
||||||
|
|
||||||
for entry in time_sorted_list:
|
|
||||||
if os.path.isdir(entry):
|
|
||||||
entry_str = f"{entry}" # Convert entry to a string
|
|
||||||
full_path = entry_str
|
|
||||||
entry_str = entry_str.replace('\\','/')
|
|
||||||
entry_str = entry_str.replace(f"{directory}", "") # Remove directory part
|
|
||||||
subfolders.append(entry_str)
|
|
||||||
|
|
||||||
return subfolders
|
|
||||||
|
|
||||||
def get_available_loras_local(_sortedByTime):
|
|
||||||
|
|
||||||
model_dir = shared.args.lora_dir # Update with the appropriate directory path
|
|
||||||
subfolders = []
|
|
||||||
if _sortedByTime:
|
|
||||||
subfolders = list_subfoldersByTime(model_dir)
|
|
||||||
else:
|
|
||||||
subfolders = utils.get_available_loras()
|
|
||||||
|
|
||||||
return subfolders
|
|
||||||
|
|
||||||
|
|
||||||
# FPHAM SPLIT BY SENTENCE BLOCK ===============
|
|
||||||
|
|
||||||
def split_sentences(text: str, cutoff_len: int):
|
|
||||||
sentences = []
|
|
||||||
sentence = ''
|
|
||||||
delimiters = ['. ', '? ', '! ', '... ', '.\n', '?\n', '!\n','...\n','</s>','<//>']
|
|
||||||
abbreviations = ['Mr. ', 'Mrs. ', 'Dr. ', 'Ms. ', 'St. ', 'Prof. ', 'Jr. ', 'Ltd. ', 'Capt. ', 'Col. ', 'Gen. ', 'Ave. ', 'Blvd. ', 'Co. ', 'Corp. ', 'Dept. ', 'Est. ', 'Gov. ', 'Inc. ', 'Ph.D. ', 'Univ. ']
|
|
||||||
errors = 0
|
|
||||||
max_cut = cutoff_len-1
|
|
||||||
prev_char = ''
|
|
||||||
|
|
||||||
for char in text:
|
|
||||||
sentence += char
|
|
||||||
|
|
||||||
|
|
||||||
if (any(sentence.endswith(delimiter) for delimiter in delimiters) and
|
|
||||||
not (prev_char.isupper() and len(sentence) >= 3 and sentence[-3] != ' ') and
|
|
||||||
not any(sentence.endswith(abbreviation) for abbreviation in abbreviations)):
|
|
||||||
tokens = shared.tokenizer.encode(sentence)
|
|
||||||
|
|
||||||
if len(tokens) > max_cut:
|
|
||||||
tokens = tokens[:max_cut]
|
|
||||||
sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True)
|
|
||||||
errors = errors + 1
|
|
||||||
|
|
||||||
sentences.append({'text': sentence, 'size': len(tokens)})
|
|
||||||
|
|
||||||
sentence = ''
|
|
||||||
|
|
||||||
prev_char = char
|
|
||||||
|
|
||||||
if sentence:
|
|
||||||
tokens = shared.tokenizer.encode(sentence)
|
|
||||||
if len(tokens) > max_cut:
|
|
||||||
tokens = tokens[:max_cut]
|
|
||||||
sentence = shared.tokenizer.decode(tokens, skip_special_tokens=True)
|
|
||||||
errors = errors + 1
|
|
||||||
|
|
||||||
sentences.append({'text': sentence, 'size': len(tokens)})
|
|
||||||
|
|
||||||
if errors > 0:
|
|
||||||
print(f"Trimmed sentences beyond Cutoff Length: {errors}")
|
|
||||||
|
|
||||||
return sentences
|
|
||||||
|
|
||||||
# The goal of following code is to create blocks of text + overlapping blocks while:
|
|
||||||
# respects sentence boundaries
|
|
||||||
# always uses all the text
|
|
||||||
# hard cut defined by hard_cut_string or </s> will always end at the end of data block
|
|
||||||
# no overlapping blocks will be created across hard cut or across </s> token
|
|
||||||
|
|
||||||
def precise_cut(text: str, overlap: bool, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer:bool):
|
|
||||||
|
|
||||||
EOSX_str = '<//>' #hardcut placeholder
|
|
||||||
EOS_str = '</s>'
|
|
||||||
print("Precise raw text slicer: ON")
|
|
||||||
|
|
||||||
cut_string = hard_cut_string.replace('\\n', '\n')
|
|
||||||
text = text.replace(cut_string, EOSX_str)
|
|
||||||
sentences = split_sentences(text, cutoff_len)
|
|
||||||
|
|
||||||
print(f"Sentences: {len(sentences)}")
|
|
||||||
sentencelist = []
|
|
||||||
currentSentence = ''
|
|
||||||
totalLength = 0
|
|
||||||
max_cut = cutoff_len-1
|
|
||||||
half_cut = cutoff_len//2
|
|
||||||
halfcut_length = 0
|
|
||||||
|
|
||||||
edgeindex = []
|
|
||||||
half_index = 0
|
|
||||||
|
|
||||||
for index, item in enumerate(sentences):
|
|
||||||
|
|
||||||
if halfcut_length+ item['size'] < half_cut:
|
|
||||||
halfcut_length += item['size']
|
|
||||||
half_index = index
|
|
||||||
else:
|
|
||||||
edgeindex.append(half_index)
|
|
||||||
halfcut_length = -2 * max_cut
|
|
||||||
|
|
||||||
|
|
||||||
if totalLength + item['size'] < max_cut and not currentSentence.endswith(EOSX_str):
|
|
||||||
currentSentence += item['text']
|
|
||||||
totalLength += item['size']
|
|
||||||
else:
|
|
||||||
|
|
||||||
if len(currentSentence.strip()) > min_chars_cut:
|
|
||||||
sentencelist.append(currentSentence.strip())
|
|
||||||
|
|
||||||
currentSentence = item['text']
|
|
||||||
totalLength = item['size']
|
|
||||||
halfcut_length = item['size']
|
|
||||||
|
|
||||||
if len(currentSentence.strip()) > min_chars_cut:
|
|
||||||
sentencelist.append(currentSentence.strip())
|
|
||||||
|
|
||||||
unique_blocks = len(sentencelist)
|
|
||||||
print(f"Text Blocks: {unique_blocks}")
|
|
||||||
|
|
||||||
#overlap strategies:
|
|
||||||
# don't overlap across HARD CUT (EOSX)
|
|
||||||
if overlap:
|
|
||||||
for edge_idx in edgeindex:
|
|
||||||
currentSentence = ''
|
|
||||||
totalLength = 0
|
|
||||||
|
|
||||||
for item in sentences[edge_idx:]:
|
|
||||||
if totalLength + item['size'] < max_cut:
|
|
||||||
currentSentence += item['text']
|
|
||||||
totalLength += item['size']
|
|
||||||
else:
|
|
||||||
#if by chance EOSX is at the end then it's acceptable
|
|
||||||
if currentSentence.endswith(EOSX_str) and len(currentSentence.strip()) > min_chars_cut:
|
|
||||||
sentencelist.append(currentSentence.strip())
|
|
||||||
# otherwise don't cross hard cut
|
|
||||||
elif EOSX_str not in currentSentence and len(currentSentence.strip()) > min_chars_cut:
|
|
||||||
sentencelist.append(currentSentence.strip())
|
|
||||||
|
|
||||||
currentSentence = ''
|
|
||||||
totalLength = 0
|
|
||||||
break
|
|
||||||
|
|
||||||
print(f"+ Overlapping blocks: {len(sentencelist)-unique_blocks}")
|
|
||||||
|
|
||||||
num_EOS = 0
|
|
||||||
for i in range(len(sentencelist)):
|
|
||||||
if eos_to_hc:
|
|
||||||
sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
|
|
||||||
else:
|
|
||||||
sentencelist[i] = sentencelist[i].replace(EOSX_str, '')
|
|
||||||
|
|
||||||
#someone may have had stop strings in the raw text...
|
|
||||||
sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
|
|
||||||
num_EOS += sentencelist[i].count(EOS_str)
|
|
||||||
|
|
||||||
if num_EOS > 0:
|
|
||||||
print(f"+ EOS count: {num_EOS}")
|
|
||||||
|
|
||||||
#final check for useless lines
|
|
||||||
sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
|
|
||||||
sentencelist = [item for item in sentencelist if item.strip() != ""]
|
|
||||||
|
|
||||||
|
|
||||||
if debug_slicer:
|
|
||||||
# Write the log file
|
|
||||||
Path('user_data/logs').mkdir(exist_ok=True)
|
|
||||||
sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
|
|
||||||
output_file = "user_data/logs/sentencelist.json"
|
|
||||||
with open(output_file, 'w') as f:
|
|
||||||
json.dump(sentencelist_dict, f,indent=2)
|
|
||||||
|
|
||||||
print("Saved sentencelist.json in user_data/logs folder")
|
|
||||||
|
|
||||||
return sentencelist
|
|
||||||
|
|
||||||
|
|
||||||
def sliding_block_cut(text: str, min_chars_cut: int, eos_to_hc: bool, cutoff_len: int, hard_cut_string: str, debug_slicer:bool):
|
|
||||||
|
|
||||||
EOSX_str = '<//>' #hardcut placeholder
|
|
||||||
EOS_str = '</s>'
|
|
||||||
print("Mega Block Overlap: ON")
|
|
||||||
|
|
||||||
cut_string = hard_cut_string.replace('\\n', '\n')
|
|
||||||
text = text.replace(cut_string, EOSX_str)
|
|
||||||
sentences = split_sentences(text, cutoff_len)
|
|
||||||
|
|
||||||
print(f"Sentences: {len(sentences)}")
|
|
||||||
sentencelist = []
|
|
||||||
|
|
||||||
max_cut = cutoff_len-1
|
|
||||||
|
|
||||||
#print(f"max_cut: {max_cut}")
|
|
||||||
advancing_to = 0
|
|
||||||
|
|
||||||
prev_block_lastsentence = ""
|
|
||||||
|
|
||||||
|
|
||||||
for i in range(len(sentences)):
|
|
||||||
totalLength = 0
|
|
||||||
currentSentence = ''
|
|
||||||
lastsentence = ""
|
|
||||||
|
|
||||||
if i >= advancing_to:
|
|
||||||
for k in range(i, len(sentences)):
|
|
||||||
|
|
||||||
current_length = sentences[k]['size']
|
|
||||||
|
|
||||||
if totalLength + current_length <= max_cut and not currentSentence.endswith(EOSX_str):
|
|
||||||
currentSentence += sentences[k]['text']
|
|
||||||
totalLength += current_length
|
|
||||||
lastsentence = sentences[k]['text']
|
|
||||||
else:
|
|
||||||
if len(currentSentence.strip()) > min_chars_cut:
|
|
||||||
if prev_block_lastsentence!=lastsentence:
|
|
||||||
sentencelist.append(currentSentence.strip())
|
|
||||||
prev_block_lastsentence = lastsentence
|
|
||||||
|
|
||||||
advancing_to = 0
|
|
||||||
if currentSentence.endswith(EOSX_str):
|
|
||||||
advancing_to = k
|
|
||||||
|
|
||||||
currentSentence = ""
|
|
||||||
totalLength = 0
|
|
||||||
break
|
|
||||||
|
|
||||||
if currentSentence != "":
|
|
||||||
if len(currentSentence.strip()) > min_chars_cut:
|
|
||||||
sentencelist.append(currentSentence.strip())
|
|
||||||
|
|
||||||
unique_blocks = len(sentencelist)
|
|
||||||
print(f"Text Blocks: {unique_blocks}")
|
|
||||||
num_EOS = 0
|
|
||||||
for i in range(len(sentencelist)):
|
|
||||||
if eos_to_hc:
|
|
||||||
sentencelist[i] = sentencelist[i].replace(EOSX_str, EOS_str)
|
|
||||||
else:
|
|
||||||
sentencelist[i] = sentencelist[i].replace(EOSX_str, '')
|
|
||||||
|
|
||||||
#someone may have had stop strings in the raw text...
|
|
||||||
sentencelist[i] = sentencelist[i].replace("</s></s>", EOS_str)
|
|
||||||
num_EOS += sentencelist[i].count(EOS_str)
|
|
||||||
|
|
||||||
if num_EOS > 0:
|
|
||||||
print(f"+ EOS count: {num_EOS}")
|
|
||||||
|
|
||||||
#final check for useless lines
|
|
||||||
sentencelist = [item for item in sentencelist if item.strip() != "</s>"]
|
|
||||||
sentencelist = [item for item in sentencelist if item.strip() != ""]
|
|
||||||
|
|
||||||
|
|
||||||
if debug_slicer:
|
|
||||||
# Write the log file
|
|
||||||
Path('user_data/logs').mkdir(exist_ok=True)
|
|
||||||
sentencelist_dict = {index: sentence for index, sentence in enumerate(sentencelist)}
|
|
||||||
output_file = "user_data/logs/sentencelist.json"
|
|
||||||
with open(output_file, 'w') as f:
|
|
||||||
json.dump(sentencelist_dict, f,indent=2)
|
|
||||||
|
|
||||||
print("Saved sentencelist.json in user_data/logs folder")
|
|
||||||
|
|
||||||
return sentencelist
|
|
||||||
|
|
||||||
# Example usage:
|
|
||||||
# download_file_from_url('https://example.com/path/to/your/file.ext', '/output/directory')
|
|
||||||
|
|
||||||
def download_file_from_url(url, overwrite, output_dir_in, valid_extensions = {'.txt', '.json'}):
|
|
||||||
try:
|
|
||||||
# Validate and sanitize the URL
|
|
||||||
#parsed_url = urllib.parse.urlparse(url)
|
|
||||||
#if not parsed_url.netloc:
|
|
||||||
# raise ValueError("Invalid URL")
|
|
||||||
#filename = os.path.basename(parsed_url.path)
|
|
||||||
|
|
||||||
# Get the filename from the URL
|
|
||||||
|
|
||||||
session = requests.Session()
|
|
||||||
headers = {}
|
|
||||||
mode = 'wb'
|
|
||||||
filename = url.split('/')[-1]
|
|
||||||
|
|
||||||
output_dir = str(output_dir_in)
|
|
||||||
# Construct the full path to the output file
|
|
||||||
local_filename = os.path.join(output_dir, filename)
|
|
||||||
|
|
||||||
# Check if the local file already exists
|
|
||||||
overw = ''
|
|
||||||
if os.path.exists(local_filename):
|
|
||||||
if not overwrite:
|
|
||||||
yield f"File '{local_filename}' already exists. Aborting."
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
overw = ' [Overwrite existing]'
|
|
||||||
|
|
||||||
filename_lower = filename.lower()
|
|
||||||
|
|
||||||
# Send an HTTP GET request to the URL with a timeout
|
|
||||||
file_extension = os.path.splitext(filename_lower)[-1]
|
|
||||||
|
|
||||||
if file_extension not in valid_extensions:
|
|
||||||
yield f"Invalid file extension: {file_extension}. Only {valid_extensions} files are supported."
|
|
||||||
return
|
|
||||||
|
|
||||||
with session.get(url, stream=True, headers=headers, timeout=10) as r:
|
|
||||||
r.raise_for_status()
|
|
||||||
# total size can be wildly inaccurate
|
|
||||||
#total_size = int(r.headers.get('content-length', 0))
|
|
||||||
|
|
||||||
block_size = 1024 * 4
|
|
||||||
with open(local_filename, mode) as f:
|
|
||||||
count = 0
|
|
||||||
for data in r.iter_content(block_size):
|
|
||||||
f.write(data)
|
|
||||||
count += len(data)
|
|
||||||
|
|
||||||
yield f"Downloaded: {count} " + overw
|
|
||||||
|
|
||||||
# Verify file size if possible
|
|
||||||
if os.path.exists(local_filename):
|
|
||||||
downloaded_size = os.path.getsize(local_filename)
|
|
||||||
if downloaded_size > 0:
|
|
||||||
yield f"File '{filename}' downloaded to '{output_dir}' ({downloaded_size} bytes)."
|
|
||||||
print("File Downloaded")
|
|
||||||
else:
|
|
||||||
print("Downloaded file is zero")
|
|
||||||
yield f"Failed. Downloaded file size is zero)."
|
|
||||||
else:
|
|
||||||
print(f"Error: {local_filename} failed to download.")
|
|
||||||
yield f"Error: {local_filename} failed to download"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred: {e}")
|
|
||||||
yield f"An error occurred: {e}"
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# Close the session to release resources
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
|
|
@ -2,6 +2,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
from modules.html_generator import get_image_cache
|
from modules.html_generator import get_image_cache
|
||||||
from modules.shared import gradio
|
from modules.shared import gradio
|
||||||
|
|
||||||
|
|
@ -72,13 +73,13 @@ def generate_html():
|
||||||
global cards
|
global cards
|
||||||
cards = []
|
cards = []
|
||||||
# Iterate through files in image folder
|
# Iterate through files in image folder
|
||||||
for file in sorted(Path("user_data/characters").glob("*")):
|
for file in sorted((shared.user_data_dir / "characters").glob("*")):
|
||||||
if file.suffix in [".json", ".yml", ".yaml"]:
|
if file.suffix in [".json", ".yml", ".yaml"]:
|
||||||
character = file.stem
|
character = file.stem
|
||||||
container_html = '<div class="character-container">'
|
container_html = '<div class="character-container">'
|
||||||
image_html = "<div class='placeholder'></div>"
|
image_html = "<div class='placeholder'></div>"
|
||||||
|
|
||||||
for path in [Path(f"user_data/characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
|
for path in [shared.user_data_dir / "characters" / f"{character}.{extension}" for extension in ['png', 'jpg', 'jpeg']]:
|
||||||
if path.exists():
|
if path.exists():
|
||||||
image_html = f'<img src="file/{get_image_cache(path)}">'
|
image_html = f'<img src="file/{get_image_cache(path)}">'
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,20 @@
|
||||||
import copy
|
import copy
|
||||||
|
import functools
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
import yaml
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from extensions.openai.errors import InvalidRequestError
|
from extensions.openai.errors import InvalidRequestError
|
||||||
from extensions.openai.typing import ToolDefinition
|
from extensions.openai.typing import ToolDefinition
|
||||||
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
|
from extensions.openai.utils import debug_msg
|
||||||
|
from modules.tool_parsing import get_tool_call_id, parse_tool_call, detect_tool_call_format
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.reasoning import extract_reasoning
|
||||||
from modules.chat import (
|
from modules.chat import (
|
||||||
generate_chat_prompt,
|
generate_chat_prompt,
|
||||||
generate_chat_reply,
|
generate_chat_reply,
|
||||||
|
|
@ -22,17 +27,126 @@ from modules.presets import load_preset_memoized
|
||||||
from modules.text_generation import decode, encode, generate_reply
|
from modules.text_generation import decode, encode, generate_reply
|
||||||
|
|
||||||
|
|
||||||
def convert_logprobs_to_tiktoken(model, logprobs):
|
@functools.cache
|
||||||
# more problems than it's worth.
|
def load_chat_template_file(filepath):
|
||||||
# try:
|
"""Load a chat template from a file path (.jinja, .jinja2, or .yaml/.yml)."""
|
||||||
# encoder = tiktoken.encoding_for_model(model)
|
filepath = Path(filepath)
|
||||||
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
ext = filepath.suffix.lower()
|
||||||
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
text = filepath.read_text(encoding='utf-8')
|
||||||
# except KeyError:
|
if ext in ['.yaml', '.yml']:
|
||||||
# # assume native tokens if we can't find the tokenizer
|
data = yaml.safe_load(text)
|
||||||
# return logprobs
|
return data.get('instruction_template', '')
|
||||||
|
return text
|
||||||
|
|
||||||
return logprobs
|
|
||||||
|
def _get_raw_logprob_entries(offset=0):
|
||||||
|
"""Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.
|
||||||
|
|
||||||
|
Returns (new_entries, new_offset).
|
||||||
|
"""
|
||||||
|
if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
|
||||||
|
return [], offset
|
||||||
|
|
||||||
|
all_entries = shared.model.last_completion_probabilities
|
||||||
|
new_entries = all_entries[offset:]
|
||||||
|
return new_entries, len(all_entries)
|
||||||
|
|
||||||
|
|
||||||
|
def _dict_to_logprob_entries(token_dict):
|
||||||
|
"""Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format."""
|
||||||
|
if not token_dict:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_entry_top(entry):
|
||||||
|
"""Extract the top logprobs list from a raw entry, handling both key names."""
|
||||||
|
return entry.get('top_logprobs', entry.get('top_probs', []))
|
||||||
|
|
||||||
|
|
||||||
|
def format_chat_logprobs(entries):
|
||||||
|
"""Format logprob entries into OpenAI chat completions logprobs format.
|
||||||
|
|
||||||
|
Output: {"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}]}
|
||||||
|
"""
|
||||||
|
if not entries:
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = []
|
||||||
|
for entry in entries:
|
||||||
|
top = _parse_entry_top(entry)
|
||||||
|
if not top:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chosen = top[0]
|
||||||
|
token_str = chosen.get('token', '')
|
||||||
|
token_logprob = chosen.get('logprob', chosen.get('prob', 0))
|
||||||
|
|
||||||
|
top_list = []
|
||||||
|
for item in top:
|
||||||
|
t = item.get('token', '')
|
||||||
|
lp = item.get('logprob', item.get('prob', 0))
|
||||||
|
top_list.append({
|
||||||
|
"token": t,
|
||||||
|
"logprob": lp,
|
||||||
|
"bytes": list(t.encode('utf-8')) if t else None
|
||||||
|
})
|
||||||
|
|
||||||
|
content.append({
|
||||||
|
"token": token_str,
|
||||||
|
"logprob": token_logprob,
|
||||||
|
"bytes": list(token_str.encode('utf-8')) if token_str else None,
|
||||||
|
"top_logprobs": top_list
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"content": content, "refusal": None} if content else None
|
||||||
|
|
||||||
|
|
||||||
|
def format_completion_logprobs(entries):
|
||||||
|
"""Format logprob entries into OpenAI completions logprobs format.
|
||||||
|
|
||||||
|
Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"}
|
||||||
|
"""
|
||||||
|
if not entries:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
token_logprobs = []
|
||||||
|
top_logprobs = []
|
||||||
|
text_offset = []
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
top = _parse_entry_top(entry)
|
||||||
|
if not top:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chosen = top[0]
|
||||||
|
token_str = chosen.get('token', '')
|
||||||
|
token_logprob = chosen.get('logprob', chosen.get('prob', 0))
|
||||||
|
|
||||||
|
tokens.append(token_str)
|
||||||
|
token_logprobs.append(token_logprob)
|
||||||
|
text_offset.append(offset)
|
||||||
|
offset += len(token_str)
|
||||||
|
|
||||||
|
top_dict = {}
|
||||||
|
for item in top:
|
||||||
|
t = item.get('token', '')
|
||||||
|
lp = item.get('logprob', item.get('prob', 0))
|
||||||
|
top_dict[t] = lp
|
||||||
|
top_logprobs.append(top_dict)
|
||||||
|
|
||||||
|
if not tokens:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"tokens": tokens,
|
||||||
|
"token_logprobs": token_logprobs,
|
||||||
|
"top_logprobs": top_logprobs,
|
||||||
|
"text_offset": text_offset
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def process_parameters(body, is_legacy=False):
|
def process_parameters(body, is_legacy=False):
|
||||||
|
|
@ -57,7 +171,16 @@ def process_parameters(body, is_legacy=False):
|
||||||
elif isinstance(body['stop'], list):
|
elif isinstance(body['stop'], list):
|
||||||
generate_params['custom_stopping_strings'] = body['stop']
|
generate_params['custom_stopping_strings'] = body['stop']
|
||||||
|
|
||||||
if shared.args.loader != 'llama.cpp':
|
# Resolve logprobs: for chat completions, logprobs is a bool and the count
|
||||||
|
# comes from top_logprobs. Normalize to an int for all backends.
|
||||||
|
logprobs = body.get('logprobs', None)
|
||||||
|
top_logprobs = body.get('top_logprobs', None)
|
||||||
|
if logprobs is True:
|
||||||
|
logprobs = max(top_logprobs, 1) if top_logprobs is not None else 5
|
||||||
|
generate_params['logprobs'] = logprobs
|
||||||
|
|
||||||
|
# For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
|
||||||
|
if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
|
||||||
from transformers import LogitsProcessorList
|
from transformers import LogitsProcessorList
|
||||||
|
|
||||||
from modules.transformers_loader import (
|
from modules.transformers_loader import (
|
||||||
|
|
@ -70,13 +193,9 @@ def process_parameters(body, is_legacy=False):
|
||||||
if logit_bias: # {str: float, ...}
|
if logit_bias: # {str: float, ...}
|
||||||
logits_processor = [LogitsBiasProcessor(logit_bias)]
|
logits_processor = [LogitsBiasProcessor(logit_bias)]
|
||||||
|
|
||||||
logprobs = None # coming to chat eventually
|
if logprobs is not None and logprobs > 0:
|
||||||
if 'logprobs' in body:
|
|
||||||
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
|
||||||
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
|
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||||
logits_processor.extend([generate_params['logprob_proc']])
|
logits_processor.extend([generate_params['logprob_proc']])
|
||||||
else:
|
|
||||||
logprobs = None
|
|
||||||
|
|
||||||
if logits_processor: # requires logits_processor support
|
if logits_processor: # requires logits_processor support
|
||||||
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||||
|
|
@ -122,38 +241,58 @@ def convert_history(history):
|
||||||
user_input = ""
|
user_input = ""
|
||||||
user_input_last = True
|
user_input_last = True
|
||||||
system_message = ""
|
system_message = ""
|
||||||
|
seen_non_system = False
|
||||||
|
|
||||||
for entry in history:
|
for entry in history:
|
||||||
content = entry["content"]
|
content = entry["content"]
|
||||||
role = entry["role"]
|
role = entry["role"]
|
||||||
|
|
||||||
if role == "user":
|
if role == "user":
|
||||||
|
seen_non_system = True
|
||||||
# Extract text content (images handled by model-specific code)
|
# Extract text content (images handled by model-specific code)
|
||||||
content = process_multimodal_content(content)
|
content = process_multimodal_content(content)
|
||||||
user_input = content
|
user_input = content
|
||||||
user_input_last = True
|
user_input_last = True
|
||||||
|
|
||||||
if current_message:
|
if current_message:
|
||||||
chat_dialogue.append([current_message, '', ''])
|
chat_dialogue.append([current_message, '', '', {}])
|
||||||
current_message = ""
|
current_message = ""
|
||||||
|
|
||||||
current_message = content
|
current_message = content
|
||||||
elif role == "assistant":
|
elif role == "assistant":
|
||||||
if "tool_calls" in entry and isinstance(entry["tool_calls"], list) and len(entry["tool_calls"]) > 0 and content.strip() == "":
|
seen_non_system = True
|
||||||
continue # skip tool calls
|
meta = {}
|
||||||
|
tool_calls = entry.get("tool_calls")
|
||||||
|
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
|
||||||
|
meta["tool_calls"] = tool_calls
|
||||||
|
if content.strip() == "":
|
||||||
|
content = "" # keep empty content, don't skip
|
||||||
|
|
||||||
current_reply = content
|
current_reply = content
|
||||||
user_input_last = False
|
user_input_last = False
|
||||||
if current_message:
|
if current_message:
|
||||||
chat_dialogue.append([current_message, current_reply, ''])
|
chat_dialogue.append([current_message, current_reply, '', meta])
|
||||||
current_message = ""
|
current_message = ""
|
||||||
current_reply = ""
|
current_reply = ""
|
||||||
else:
|
else:
|
||||||
chat_dialogue.append(['', current_reply, ''])
|
chat_dialogue.append(['', current_reply, '', meta])
|
||||||
elif role == "tool":
|
elif role == "tool":
|
||||||
|
seen_non_system = True
|
||||||
user_input_last = False
|
user_input_last = False
|
||||||
chat_dialogue.append(['', '', content])
|
meta = {}
|
||||||
elif role == "system":
|
if "tool_call_id" in entry:
|
||||||
|
meta["tool_call_id"] = entry["tool_call_id"]
|
||||||
|
chat_dialogue.append(['', '', content, meta])
|
||||||
|
elif role in ("system", "developer"):
|
||||||
|
if not seen_non_system:
|
||||||
|
# Leading system messages go to custom_system_message (placed at top)
|
||||||
system_message += f"\n{content}" if system_message else content
|
system_message += f"\n{content}" if system_message else content
|
||||||
|
else:
|
||||||
|
# Mid-conversation system messages: preserve position in history
|
||||||
|
if current_message:
|
||||||
|
chat_dialogue.append([current_message, '', '', {}])
|
||||||
|
current_message = ""
|
||||||
|
chat_dialogue.append([content, '', '', {"role": "system"}])
|
||||||
|
|
||||||
if not user_input_last:
|
if not user_input_last:
|
||||||
user_input = ""
|
user_input = ""
|
||||||
|
|
@ -165,7 +304,7 @@ def convert_history(history):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
|
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False, stop_event=None) -> dict:
|
||||||
if body.get('functions', []):
|
if body.get('functions', []):
|
||||||
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
||||||
|
|
||||||
|
|
@ -179,6 +318,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
||||||
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
|
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
|
||||||
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
|
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
|
||||||
|
|
||||||
|
tool_choice = body.get('tool_choice', None)
|
||||||
|
if tool_choice == "none":
|
||||||
|
tools = None # Disable tool detection entirely
|
||||||
|
|
||||||
messages = body['messages']
|
messages = body['messages']
|
||||||
for m in messages:
|
for m in messages:
|
||||||
if 'role' not in m:
|
if 'role' not in m:
|
||||||
|
|
@ -189,6 +332,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
||||||
# Handle multimodal content validation
|
# Handle multimodal content validation
|
||||||
content = m.get('content')
|
content = m.get('content')
|
||||||
if content is None:
|
if content is None:
|
||||||
|
# OpenAI allows content: null on assistant messages when tool_calls is present
|
||||||
|
if m['role'] == 'assistant' and m.get('tool_calls'):
|
||||||
|
m['content'] = ''
|
||||||
|
else:
|
||||||
raise InvalidRequestError(message="messages: missing content", param='messages')
|
raise InvalidRequestError(message="messages: missing content", param='messages')
|
||||||
|
|
||||||
# Validate multimodal content structure
|
# Validate multimodal content structure
|
||||||
|
|
@ -211,6 +358,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
||||||
|
|
||||||
# generation parameters
|
# generation parameters
|
||||||
generate_params = process_parameters(body, is_legacy=is_legacy)
|
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||||
|
if stop_event is not None:
|
||||||
|
generate_params['stop_event'] = stop_event
|
||||||
continue_ = body['continue_']
|
continue_ = body['continue_']
|
||||||
|
|
||||||
# Instruction template
|
# Instruction template
|
||||||
|
|
@ -220,6 +369,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
||||||
instruction_template = body['instruction_template']
|
instruction_template = body['instruction_template']
|
||||||
instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
|
instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
|
||||||
instruction_template_str = load_instruction_template_memoized(instruction_template)
|
instruction_template_str = load_instruction_template_memoized(instruction_template)
|
||||||
|
elif shared.args.chat_template_file:
|
||||||
|
instruction_template_str = load_chat_template_file(shared.args.chat_template_file)
|
||||||
else:
|
else:
|
||||||
instruction_template_str = shared.settings['instruction_template_str']
|
instruction_template_str = shared.settings['instruction_template_str']
|
||||||
|
|
||||||
|
|
@ -262,106 +413,189 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
||||||
|
|
||||||
requested_model = generate_params.pop('model')
|
requested_model = generate_params.pop('model')
|
||||||
logprob_proc = generate_params.pop('logprob_proc', None)
|
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||||
|
if logprob_proc:
|
||||||
|
logprob_proc.token_alternatives_history.clear()
|
||||||
|
chat_logprobs_offset = [0] # mutable for closure access in streaming
|
||||||
|
|
||||||
def chat_streaming_chunk(content, chunk_tool_calls=None):
|
def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False, reasoning_content=None):
|
||||||
# begin streaming
|
# begin streaming
|
||||||
|
delta = {}
|
||||||
|
if include_role:
|
||||||
|
delta['role'] = 'assistant'
|
||||||
|
delta['refusal'] = None
|
||||||
|
if content is not None:
|
||||||
|
delta['content'] = content
|
||||||
|
if reasoning_content is not None:
|
||||||
|
delta['reasoning_content'] = reasoning_content
|
||||||
|
if chunk_tool_calls:
|
||||||
|
delta['tool_calls'] = chunk_tool_calls
|
||||||
|
|
||||||
chunk = {
|
chunk = {
|
||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
"object": object_type,
|
"object": object_type,
|
||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": shared.model_name,
|
"model": shared.model_name,
|
||||||
|
"system_fingerprint": None,
|
||||||
resp_list: [{
|
resp_list: [{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": None,
|
"finish_reason": None,
|
||||||
"delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls},
|
"delta": delta,
|
||||||
|
"logprobs": None,
|
||||||
}],
|
}],
|
||||||
}
|
}
|
||||||
|
|
||||||
if logprob_proc: # not official for chat yet
|
if logprob_proc:
|
||||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
entries = _dict_to_logprob_entries(logprob_proc.token_alternatives)
|
||||||
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
formatted = format_chat_logprobs(entries)
|
||||||
# else:
|
if formatted:
|
||||||
# chunk[resp_list][0]["logprobs"] = None
|
chunk[resp_list][0]["logprobs"] = formatted
|
||||||
|
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
|
||||||
|
entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0])
|
||||||
|
if entries:
|
||||||
|
formatted = format_chat_logprobs(entries)
|
||||||
|
if formatted:
|
||||||
|
chunk[resp_list][0]["logprobs"] = formatted
|
||||||
|
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
|
# Check if usage should be included in streaming chunks per OpenAI spec
|
||||||
|
stream_options = body.get('stream_options')
|
||||||
|
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
|
||||||
|
|
||||||
# generate reply #######################################
|
# generate reply #######################################
|
||||||
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
|
|
||||||
if prompt_only:
|
if prompt_only:
|
||||||
|
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
|
||||||
yield {'prompt': prompt}
|
yield {'prompt': prompt}
|
||||||
return
|
return
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
yield chat_streaming_chunk('')
|
chunk = chat_streaming_chunk('', include_role=True)
|
||||||
|
if include_usage:
|
||||||
|
chunk['usage'] = None
|
||||||
|
yield chunk
|
||||||
|
|
||||||
generator = generate_chat_reply(
|
generator = generate_chat_reply(
|
||||||
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
|
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
seen_content = ''
|
seen_content = ''
|
||||||
|
seen_reasoning = ''
|
||||||
|
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
end_last_tool_call = 0
|
end_last_tool_call = 0
|
||||||
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
|
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
|
||||||
|
_tool_parsers = None
|
||||||
|
|
||||||
|
# Filter supported_tools when tool_choice specifies a particular function
|
||||||
|
if supported_tools and isinstance(tool_choice, dict):
|
||||||
|
specified_func = tool_choice.get("function", {}).get("name")
|
||||||
|
if specified_func and specified_func in supported_tools:
|
||||||
|
supported_tools = [specified_func]
|
||||||
|
|
||||||
|
if supported_tools is not None:
|
||||||
|
_template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '')
|
||||||
|
_tool_parsers, _, _ = detect_tool_call_format(_template_str)
|
||||||
|
|
||||||
for a in generator:
|
for a in generator:
|
||||||
answer = a['internal'][-1][1]
|
answer = a['internal'][-1][1]
|
||||||
|
|
||||||
if supported_tools is not None:
|
if supported_tools is not None:
|
||||||
tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
|
tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else []
|
||||||
if len(tool_call) > 0:
|
if len(tool_call) > 0:
|
||||||
for tc in tool_call:
|
for tc in tool_call:
|
||||||
tc["id"] = getToolCallId()
|
tc["id"] = get_tool_call_id()
|
||||||
tc["index"] = str(len(tool_calls))
|
if stream:
|
||||||
|
tc["index"] = len(tool_calls)
|
||||||
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
|
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
|
||||||
tool_calls.append(tc)
|
tool_calls.append(tc)
|
||||||
end_last_tool_call = len(answer)
|
end_last_tool_call = len(answer)
|
||||||
|
|
||||||
if stream:
|
# Stop generation before streaming content if tool_calls were detected,
|
||||||
len_seen = len(seen_content)
|
# so that raw tool markup is not sent as content deltas.
|
||||||
new_content = answer[len_seen:]
|
|
||||||
|
|
||||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
|
||||||
continue
|
|
||||||
|
|
||||||
chunk = chat_streaming_chunk(new_content)
|
|
||||||
|
|
||||||
seen_content = answer
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
# stop generation if tool_calls were generated previously
|
|
||||||
if len(tool_calls) > 0:
|
if len(tool_calls) > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
if stream:
|
||||||
|
# Strip reasoning/thinking blocks so only final content is streamed.
|
||||||
|
# Reasoning is emitted separately as reasoning_content deltas.
|
||||||
|
reasoning, content = extract_reasoning(answer)
|
||||||
|
if reasoning is not None:
|
||||||
|
new_reasoning = reasoning[len(seen_reasoning):]
|
||||||
|
new_content = content[len(seen_content):]
|
||||||
|
else:
|
||||||
|
new_reasoning = None
|
||||||
|
new_content = answer[len(seen_content):]
|
||||||
|
|
||||||
|
if (not new_content and not new_reasoning) or chr(0xfffd) in (new_content or '') + (new_reasoning or ''):
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = chat_streaming_chunk(
|
||||||
|
content=new_content if new_content else None,
|
||||||
|
reasoning_content=new_reasoning if new_reasoning else None,
|
||||||
|
)
|
||||||
|
if include_usage:
|
||||||
|
chunk['usage'] = None
|
||||||
|
|
||||||
|
if reasoning is not None:
|
||||||
|
seen_reasoning = reasoning
|
||||||
|
seen_content = content
|
||||||
|
else:
|
||||||
|
seen_content = answer
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
|
||||||
completion_token_count = len(encode(answer)[0])
|
completion_token_count = len(encode(answer)[0])
|
||||||
stop_reason = "stop"
|
|
||||||
if len(tool_calls) > 0:
|
if len(tool_calls) > 0:
|
||||||
stop_reason = "tool_calls"
|
stop_reason = "tool_calls"
|
||||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
|
elif token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
|
||||||
stop_reason = "length"
|
stop_reason = "length"
|
||||||
|
else:
|
||||||
|
stop_reason = "stop"
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
chunk = chat_streaming_chunk('', tool_calls)
|
chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls)
|
||||||
chunk[resp_list][0]['finish_reason'] = stop_reason
|
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||||
chunk['usage'] = {
|
usage = {
|
||||||
"prompt_tokens": token_count,
|
"prompt_tokens": token_count,
|
||||||
"completion_tokens": completion_token_count,
|
"completion_tokens": completion_token_count,
|
||||||
"total_tokens": token_count + completion_token_count
|
"total_tokens": token_count + completion_token_count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if include_usage:
|
||||||
|
chunk['usage'] = None
|
||||||
|
yield chunk
|
||||||
|
# Separate usage-only chunk with choices: [] per OpenAI spec
|
||||||
|
yield {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
"system_fingerprint": None,
|
||||||
|
resp_list: [],
|
||||||
|
"usage": usage
|
||||||
|
}
|
||||||
|
else:
|
||||||
yield chunk
|
yield chunk
|
||||||
else:
|
else:
|
||||||
|
reasoning, content = extract_reasoning(answer)
|
||||||
|
message = {
|
||||||
|
"role": "assistant",
|
||||||
|
"refusal": None,
|
||||||
|
"content": None if tool_calls else content,
|
||||||
|
**({"reasoning_content": reasoning} if reasoning else {}),
|
||||||
|
**({"tool_calls": tool_calls} if tool_calls else {}),
|
||||||
|
}
|
||||||
resp = {
|
resp = {
|
||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
"object": object_type,
|
"object": object_type,
|
||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": shared.model_name,
|
"model": shared.model_name,
|
||||||
|
"system_fingerprint": None,
|
||||||
resp_list: [{
|
resp_list: [{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": stop_reason,
|
"finish_reason": stop_reason,
|
||||||
"message": {"role": "assistant", "content": answer},
|
"message": message,
|
||||||
"tool_calls": tool_calls
|
"logprobs": None,
|
||||||
}],
|
}],
|
||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": token_count,
|
"prompt_tokens": token_count,
|
||||||
|
|
@ -369,19 +603,27 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
||||||
"total_tokens": token_count + completion_token_count
|
"total_tokens": token_count + completion_token_count
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if logprob_proc: # not official for chat yet
|
if logprob_proc:
|
||||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
all_entries = []
|
||||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
for alt in logprob_proc.token_alternatives_history:
|
||||||
# else:
|
all_entries.extend(_dict_to_logprob_entries(alt))
|
||||||
# resp[resp_list][0]["logprobs"] = None
|
formatted = format_chat_logprobs(all_entries)
|
||||||
|
if formatted:
|
||||||
|
resp[resp_list][0]["logprobs"] = formatted
|
||||||
|
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
|
||||||
|
raw = getattr(shared.model, 'last_completion_probabilities', None)
|
||||||
|
if raw:
|
||||||
|
formatted = format_chat_logprobs(raw)
|
||||||
|
if formatted:
|
||||||
|
resp[resp_list][0]["logprobs"] = formatted
|
||||||
|
|
||||||
yield resp
|
yield resp
|
||||||
|
|
||||||
|
|
||||||
def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
|
||||||
object_type = 'text_completion.chunk' if stream else 'text_completion'
|
object_type = 'text_completion'
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
cmpl_id = "cmpl-%d" % (int(time.time() * 1000000000))
|
||||||
resp_list = 'data' if is_legacy else 'choices'
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
prompt_str = 'context' if is_legacy else 'prompt'
|
prompt_str = 'context' if is_legacy else 'prompt'
|
||||||
|
|
@ -411,8 +653,12 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||||
generate_params = process_parameters(body, is_legacy=is_legacy)
|
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||||
max_tokens = generate_params['max_new_tokens']
|
max_tokens = generate_params['max_new_tokens']
|
||||||
generate_params['stream'] = stream
|
generate_params['stream'] = stream
|
||||||
|
if stop_event is not None:
|
||||||
|
generate_params['stop_event'] = stop_event
|
||||||
requested_model = generate_params.pop('model')
|
requested_model = generate_params.pop('model')
|
||||||
logprob_proc = generate_params.pop('logprob_proc', None)
|
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||||
|
if logprob_proc:
|
||||||
|
logprob_proc.token_alternatives_history.clear()
|
||||||
suffix = body['suffix'] if body['suffix'] else ''
|
suffix = body['suffix'] if body['suffix'] else ''
|
||||||
echo = body['echo']
|
echo = body['echo']
|
||||||
|
|
||||||
|
|
@ -424,6 +670,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||||
logger.info(f"Found {len(raw_images)} image(s) in request.")
|
logger.info(f"Found {len(raw_images)} image(s) in request.")
|
||||||
generate_params['raw_images'] = raw_images
|
generate_params['raw_images'] = raw_images
|
||||||
|
|
||||||
|
n_completions = body.get('n', 1) or 1
|
||||||
|
|
||||||
if not stream:
|
if not stream:
|
||||||
prompt_arg = body[prompt_str]
|
prompt_arg = body[prompt_str]
|
||||||
|
|
||||||
|
|
@ -437,6 +685,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||||
resp_list_data = []
|
resp_list_data = []
|
||||||
total_completion_token_count = 0
|
total_completion_token_count = 0
|
||||||
total_prompt_token_count = 0
|
total_prompt_token_count = 0
|
||||||
|
choice_index = 0
|
||||||
|
|
||||||
for idx, prompt in enumerate(prompt_arg, start=0):
|
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||||
if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
|
if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
|
||||||
|
|
@ -451,6 +700,17 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||||
prompt = decode(prompt)[0]
|
prompt = decode(prompt)[0]
|
||||||
|
|
||||||
prefix = prompt if echo else ''
|
prefix = prompt if echo else ''
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
total_prompt_token_count += token_count
|
||||||
|
|
||||||
|
original_seed = generate_params.get('seed', -1)
|
||||||
|
for _n in range(n_completions):
|
||||||
|
# Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
|
||||||
|
if original_seed >= 0:
|
||||||
|
generate_params['seed'] = original_seed + _n
|
||||||
|
|
||||||
|
if logprob_proc:
|
||||||
|
logprob_proc.token_alternatives_history.clear()
|
||||||
|
|
||||||
# generate reply #######################################
|
# generate reply #######################################
|
||||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||||
|
|
@ -460,28 +720,39 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||||
for a in generator:
|
for a in generator:
|
||||||
answer = a
|
answer = a
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
|
||||||
total_prompt_token_count += token_count
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
completion_token_count = len(encode(answer)[0])
|
||||||
total_completion_token_count += completion_token_count
|
total_completion_token_count += completion_token_count
|
||||||
stop_reason = "stop"
|
stop_reason = "stop"
|
||||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
stop_reason = "length"
|
stop_reason = "length"
|
||||||
|
|
||||||
|
if logprob_proc:
|
||||||
|
all_entries = []
|
||||||
|
for alt in logprob_proc.token_alternatives_history:
|
||||||
|
all_entries.extend(_dict_to_logprob_entries(alt))
|
||||||
|
completion_logprobs = format_completion_logprobs(all_entries)
|
||||||
|
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
|
||||||
|
raw = getattr(shared.model, 'last_completion_probabilities', None)
|
||||||
|
completion_logprobs = format_completion_logprobs(raw)
|
||||||
|
else:
|
||||||
|
completion_logprobs = None
|
||||||
|
|
||||||
respi = {
|
respi = {
|
||||||
"index": idx,
|
"index": choice_index,
|
||||||
"finish_reason": stop_reason,
|
"finish_reason": stop_reason,
|
||||||
"text": prefix + answer + suffix,
|
"text": prefix + answer + suffix,
|
||||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
"logprobs": completion_logprobs,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp_list_data.extend([respi])
|
resp_list_data.append(respi)
|
||||||
|
choice_index += 1
|
||||||
|
|
||||||
resp = {
|
resp = {
|
||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
"object": object_type,
|
"object": object_type,
|
||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": shared.model_name,
|
"model": shared.model_name,
|
||||||
|
"system_fingerprint": None,
|
||||||
resp_list: resp_list_data,
|
resp_list: resp_list_data,
|
||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": total_prompt_token_count,
|
"prompt_tokens": total_prompt_token_count,
|
||||||
|
|
@ -506,24 +777,41 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||||
prefix = prompt if echo else ''
|
prefix = prompt if echo else ''
|
||||||
token_count = len(encode(prompt)[0])
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
|
# Check if usage should be included in streaming chunks per OpenAI spec
|
||||||
|
stream_options = body.get('stream_options')
|
||||||
|
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
|
||||||
|
cmpl_logprobs_offset = [0] # mutable for closure access in streaming
|
||||||
|
|
||||||
def text_streaming_chunk(content):
|
def text_streaming_chunk(content):
|
||||||
# begin streaming
|
# begin streaming
|
||||||
|
if logprob_proc:
|
||||||
|
chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives))
|
||||||
|
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
|
||||||
|
entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0])
|
||||||
|
chunk_logprobs = format_completion_logprobs(entries) if entries else None
|
||||||
|
else:
|
||||||
|
chunk_logprobs = None
|
||||||
|
|
||||||
chunk = {
|
chunk = {
|
||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
"object": object_type,
|
"object": object_type,
|
||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": shared.model_name,
|
"model": shared.model_name,
|
||||||
|
"system_fingerprint": None,
|
||||||
resp_list: [{
|
resp_list: [{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": None,
|
"finish_reason": None,
|
||||||
"text": content,
|
"text": content,
|
||||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
"logprobs": chunk_logprobs,
|
||||||
}],
|
}],
|
||||||
}
|
}
|
||||||
|
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
yield text_streaming_chunk(prefix)
|
chunk = text_streaming_chunk(prefix)
|
||||||
|
if include_usage:
|
||||||
|
chunk['usage'] = None
|
||||||
|
yield chunk
|
||||||
|
|
||||||
# generate reply #######################################
|
# generate reply #######################################
|
||||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||||
|
|
@ -543,6 +831,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||||
|
|
||||||
seen_content = answer
|
seen_content = answer
|
||||||
chunk = text_streaming_chunk(new_content)
|
chunk = text_streaming_chunk(new_content)
|
||||||
|
if include_usage:
|
||||||
|
chunk['usage'] = None
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
|
@ -552,32 +842,46 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||||
|
|
||||||
chunk = text_streaming_chunk(suffix)
|
chunk = text_streaming_chunk(suffix)
|
||||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||||
chunk["usage"] = {
|
usage = {
|
||||||
"prompt_tokens": token_count,
|
"prompt_tokens": token_count,
|
||||||
"completion_tokens": completion_token_count,
|
"completion_tokens": completion_token_count,
|
||||||
"total_tokens": token_count + completion_token_count
|
"total_tokens": token_count + completion_token_count
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if include_usage:
|
||||||
|
chunk['usage'] = None
|
||||||
|
yield chunk
|
||||||
|
# Separate usage-only chunk with choices: [] per OpenAI spec
|
||||||
|
yield {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
"system_fingerprint": None,
|
||||||
|
resp_list: [],
|
||||||
|
"usage": usage
|
||||||
|
}
|
||||||
|
else:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
|
||||||
generator = chat_completions_common(body, is_legacy, stream=False)
|
generator = chat_completions_common(body, is_legacy, stream=False, stop_event=stop_event)
|
||||||
return deque(generator, maxlen=1).pop()
|
return deque(generator, maxlen=1).pop()
|
||||||
|
|
||||||
|
|
||||||
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
def stream_chat_completions(body: dict, is_legacy: bool = False, stop_event=None):
|
||||||
for resp in chat_completions_common(body, is_legacy, stream=True):
|
for resp in chat_completions_common(body, is_legacy, stream=True, stop_event=stop_event):
|
||||||
yield resp
|
yield resp
|
||||||
|
|
||||||
|
|
||||||
def completions(body: dict, is_legacy: bool = False) -> dict:
|
def completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
|
||||||
generator = completions_common(body, is_legacy, stream=False)
|
generator = completions_common(body, is_legacy, stream=False, stop_event=stop_event)
|
||||||
return deque(generator, maxlen=1).pop()
|
return deque(generator, maxlen=1).pop()
|
||||||
|
|
||||||
|
|
||||||
def stream_completions(body: dict, is_legacy: bool = False):
|
def stream_completions(body: dict, is_legacy: bool = False, stop_event=None):
|
||||||
for resp in completions_common(body, is_legacy, stream=True):
|
for resp in completions_common(body, is_legacy, stream=True, stop_event=stop_event):
|
||||||
yield resp
|
yield resp
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -588,6 +892,12 @@ def validateTools(tools: list[dict]):
|
||||||
tool = tools[idx]
|
tool = tools[idx]
|
||||||
try:
|
try:
|
||||||
tool_definition = ToolDefinition(**tool)
|
tool_definition = ToolDefinition(**tool)
|
||||||
|
# Backfill defaults so Jinja2 templates don't crash on missing fields
|
||||||
|
func = tool.get("function", {})
|
||||||
|
if "description" not in func:
|
||||||
|
func["description"] = ""
|
||||||
|
if "parameters" not in func:
|
||||||
|
func["parameters"] = {"type": "object", "properties": {}}
|
||||||
if valid_tools is None:
|
if valid_tools is None:
|
||||||
valid_tools = []
|
valid_tools = []
|
||||||
valid_tools.append(tool)
|
valid_tools.append(tool)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from modules import shared
|
from modules import loaders, shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, unload_model
|
from modules.models import load_model, unload_model
|
||||||
|
|
@ -20,10 +20,14 @@ def list_models():
|
||||||
|
|
||||||
def list_models_openai_format():
|
def list_models_openai_format():
|
||||||
"""Returns model list in OpenAI API format"""
|
"""Returns model list in OpenAI API format"""
|
||||||
model_names = get_available_models()
|
if shared.model_name and shared.model_name != 'None':
|
||||||
|
data = [model_info_dict(shared.model_name)]
|
||||||
|
else:
|
||||||
|
data = []
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": [model_info_dict(name) for name in model_names]
|
"data": data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -46,9 +50,14 @@ def _load_model(data):
|
||||||
update_model_parameters(model_settings)
|
update_model_parameters(model_settings)
|
||||||
|
|
||||||
# Update shared.args with custom model loading settings
|
# Update shared.args with custom model loading settings
|
||||||
|
# Security: only allow keys that correspond to model loading
|
||||||
|
# parameters exposed in the UI. Never allow security-sensitive
|
||||||
|
# flags like trust_remote_code or extra_flags to be set via the API.
|
||||||
|
blocked_keys = {'extra_flags'}
|
||||||
|
allowed_keys = set(loaders.list_model_elements()) - blocked_keys
|
||||||
if args:
|
if args:
|
||||||
for k in args:
|
for k in args:
|
||||||
if hasattr(shared.args, k):
|
if k in allowed_keys and hasattr(shared.args, k):
|
||||||
setattr(shared.args, k, args[k])
|
setattr(shared.args, k, args[k])
|
||||||
|
|
||||||
shared.model, shared.tokenizer = load_model(model_name)
|
shared.model, shared.tokenizer = load_model(model_name)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
@ -20,11 +21,12 @@ import extensions.openai.completions as OAIcompletions
|
||||||
import extensions.openai.logits as OAIlogits
|
import extensions.openai.logits as OAIlogits
|
||||||
import extensions.openai.models as OAImodels
|
import extensions.openai.models as OAImodels
|
||||||
from extensions.openai.tokens import token_count, token_decode, token_encode
|
from extensions.openai.tokens import token_count, token_decode, token_encode
|
||||||
|
from extensions.openai.errors import OpenAIError
|
||||||
from extensions.openai.utils import _start_cloudflared
|
from extensions.openai.utils import _start_cloudflared
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.models import unload_model
|
from modules.models import unload_model
|
||||||
from modules.text_generation import stop_everything_event
|
from modules.text_generation import stop_everything_event # used by /v1/internal/stop-generation
|
||||||
|
|
||||||
from .typing import (
|
from .typing import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
|
@ -58,8 +60,13 @@ params = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
streaming_semaphore = asyncio.Semaphore(1)
|
async def _wait_for_disconnect(request: Request, stop_event: threading.Event):
|
||||||
image_generation_semaphore = asyncio.Semaphore(1)
|
"""Block until the client disconnects, then signal the stop_event."""
|
||||||
|
while True:
|
||||||
|
message = await request.receive()
|
||||||
|
if message["type"] == "http.disconnect":
|
||||||
|
stop_event.set()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def verify_api_key(authorization: str = Header(None)) -> None:
|
def verify_api_key(authorization: str = Header(None)) -> None:
|
||||||
|
|
@ -88,6 +95,20 @@ app.add_middleware(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(OpenAIError)
|
||||||
|
async def openai_error_handler(request: Request, exc: OpenAIError):
|
||||||
|
error_type = "server_error" if exc.code >= 500 else "invalid_request_error"
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.code,
|
||||||
|
content={"error": {
|
||||||
|
"message": exc.message,
|
||||||
|
"type": error_type,
|
||||||
|
"param": getattr(exc, 'param', None),
|
||||||
|
"code": None
|
||||||
|
}}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def validate_host_header(request: Request, call_next):
|
async def validate_host_header(request: Request, call_next):
|
||||||
# Be strict about only approving access to localhost by default
|
# Be strict about only approving access to localhost by default
|
||||||
|
|
@ -113,29 +134,44 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
|
||||||
is_legacy = "/generate" in path
|
is_legacy = "/generate" in path
|
||||||
|
|
||||||
if request_data.stream:
|
if request_data.stream:
|
||||||
|
if (request_data.n or 1) > 1:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
|
||||||
|
)
|
||||||
|
|
||||||
|
stop_event = threading.Event()
|
||||||
|
|
||||||
async def generator():
|
async def generator():
|
||||||
async with streaming_semaphore:
|
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
|
||||||
try:
|
try:
|
||||||
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
|
|
||||||
async for resp in iterate_in_threadpool(response):
|
async for resp in iterate_in_threadpool(response):
|
||||||
disconnected = await request.is_disconnected()
|
disconnected = await request.is_disconnected()
|
||||||
if disconnected:
|
if disconnected:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield {"data": json.dumps(resp)}
|
yield {"data": json.dumps(resp)}
|
||||||
finally:
|
|
||||||
stop_everything_event()
|
|
||||||
response.close()
|
|
||||||
return
|
|
||||||
|
|
||||||
return EventSourceResponse(generator()) # SSE streaming
|
yield {"data": "[DONE]"}
|
||||||
|
finally:
|
||||||
|
stop_event.set()
|
||||||
|
response.close()
|
||||||
|
|
||||||
|
return EventSourceResponse(generator(), sep="\n") # SSE streaming
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
stop_event = threading.Event()
|
||||||
|
monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
|
||||||
|
try:
|
||||||
response = await asyncio.to_thread(
|
response = await asyncio.to_thread(
|
||||||
OAIcompletions.completions,
|
OAIcompletions.completions,
|
||||||
to_dict(request_data),
|
to_dict(request_data),
|
||||||
is_legacy=is_legacy
|
is_legacy=is_legacy,
|
||||||
|
stop_event=stop_event
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
stop_event.set()
|
||||||
|
monitor.cancel()
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
@ -146,29 +182,38 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
|
||||||
is_legacy = "/generate" in path
|
is_legacy = "/generate" in path
|
||||||
|
|
||||||
if request_data.stream:
|
if request_data.stream:
|
||||||
|
stop_event = threading.Event()
|
||||||
|
|
||||||
async def generator():
|
async def generator():
|
||||||
async with streaming_semaphore:
|
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
|
||||||
try:
|
try:
|
||||||
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
|
||||||
async for resp in iterate_in_threadpool(response):
|
async for resp in iterate_in_threadpool(response):
|
||||||
disconnected = await request.is_disconnected()
|
disconnected = await request.is_disconnected()
|
||||||
if disconnected:
|
if disconnected:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield {"data": json.dumps(resp)}
|
yield {"data": json.dumps(resp)}
|
||||||
finally:
|
|
||||||
stop_everything_event()
|
|
||||||
response.close()
|
|
||||||
return
|
|
||||||
|
|
||||||
return EventSourceResponse(generator()) # SSE streaming
|
yield {"data": "[DONE]"}
|
||||||
|
finally:
|
||||||
|
stop_event.set()
|
||||||
|
response.close()
|
||||||
|
|
||||||
|
return EventSourceResponse(generator(), sep="\n") # SSE streaming
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
stop_event = threading.Event()
|
||||||
|
monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
|
||||||
|
try:
|
||||||
response = await asyncio.to_thread(
|
response = await asyncio.to_thread(
|
||||||
OAIcompletions.chat_completions,
|
OAIcompletions.chat_completions,
|
||||||
to_dict(request_data),
|
to_dict(request_data),
|
||||||
is_legacy=is_legacy
|
is_legacy=is_legacy,
|
||||||
|
stop_event=stop_event
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
stop_event.set()
|
||||||
|
monitor.cancel()
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
@ -232,7 +277,6 @@ async def handle_audio_transcription(request: Request):
|
||||||
async def handle_image_generation(request_data: ImageGenerationRequest):
|
async def handle_image_generation(request_data: ImageGenerationRequest):
|
||||||
import extensions.openai.images as OAIimages
|
import extensions.openai.images as OAIimages
|
||||||
|
|
||||||
async with image_generation_semaphore:
|
|
||||||
response = await asyncio.to_thread(OAIimages.generations, request_data)
|
response = await asyncio.to_thread(OAIimages.generations, request_data)
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
@ -357,9 +401,9 @@ async def handle_load_model(request_data: LoadModelRequest):
|
||||||
try:
|
try:
|
||||||
OAImodels._load_model(to_dict(request_data))
|
OAImodels._load_model(to_dict(request_data))
|
||||||
return JSONResponse(content="OK")
|
return JSONResponse(content="OK")
|
||||||
except:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return HTTPException(status_code=400, detail="Failed to load the model.")
|
raise HTTPException(status_code=400, detail="Failed to load the model.")
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
|
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
|
||||||
|
|
@ -378,9 +422,9 @@ async def handle_load_loras(request_data: LoadLorasRequest):
|
||||||
try:
|
try:
|
||||||
OAImodels.load_loras(request_data.lora_names)
|
OAImodels.load_loras(request_data.lora_names)
|
||||||
return JSONResponse(content="OK")
|
return JSONResponse(content="OK")
|
||||||
except:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")
|
raise HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
|
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
|
||||||
|
|
@ -414,6 +458,9 @@ def run_server():
|
||||||
|
|
||||||
# In the server configuration:
|
# In the server configuration:
|
||||||
server_addrs = []
|
server_addrs = []
|
||||||
|
if shared.args.listen and shared.args.listen_host:
|
||||||
|
server_addrs.append(shared.args.listen_host)
|
||||||
|
else:
|
||||||
if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
|
if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
|
||||||
server_addrs.append('[::]' if shared.args.listen else '[::1]')
|
server_addrs.append('[::]' if shared.args.listen else '[::1]')
|
||||||
if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
|
if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
|
||||||
|
|
@ -428,11 +475,11 @@ def run_server():
|
||||||
port,
|
port,
|
||||||
shared.args.public_api_id,
|
shared.args.public_api_id,
|
||||||
max_attempts=3,
|
max_attempts=3,
|
||||||
on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}\n')
|
on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}/v1\n')
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
|
url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
|
||||||
urls = [f'{url_proto}{addr}:{port}' for addr in server_addrs]
|
urls = [f'{url_proto}{addr}:{port}/v1' for addr in server_addrs]
|
||||||
if len(urls) > 1:
|
if len(urls) > 1:
|
||||||
logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
|
logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,57 +1,61 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator, validator
|
from pydantic import BaseModel, ConfigDict, Field, model_validator, validator
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
|
||||||
class GenerationOptions(BaseModel):
|
class GenerationOptions(BaseModel):
|
||||||
preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/user_data/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
|
preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/user_data/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
|
||||||
dynatemp_low: float = 1
|
dynatemp_low: float = shared.args.dynatemp_low
|
||||||
dynatemp_high: float = 1
|
dynatemp_high: float = shared.args.dynatemp_high
|
||||||
dynatemp_exponent: float = 1
|
dynatemp_exponent: float = shared.args.dynatemp_exponent
|
||||||
smoothing_factor: float = 0
|
smoothing_factor: float = shared.args.smoothing_factor
|
||||||
smoothing_curve: float = 1
|
smoothing_curve: float = shared.args.smoothing_curve
|
||||||
min_p: float = 0
|
min_p: float = shared.args.min_p
|
||||||
top_k: int = 0
|
top_k: int = shared.args.top_k
|
||||||
typical_p: float = 1
|
typical_p: float = shared.args.typical_p
|
||||||
xtc_threshold: float = 0.1
|
xtc_threshold: float = shared.args.xtc_threshold
|
||||||
xtc_probability: float = 0
|
xtc_probability: float = shared.args.xtc_probability
|
||||||
epsilon_cutoff: float = 0
|
epsilon_cutoff: float = shared.args.epsilon_cutoff
|
||||||
eta_cutoff: float = 0
|
eta_cutoff: float = shared.args.eta_cutoff
|
||||||
tfs: float = 1
|
tfs: float = shared.args.tfs
|
||||||
top_a: float = 0
|
top_a: float = shared.args.top_a
|
||||||
top_n_sigma: float = 0
|
top_n_sigma: float = shared.args.top_n_sigma
|
||||||
dry_multiplier: float = 0
|
adaptive_target: float = shared.args.adaptive_target
|
||||||
dry_allowed_length: int = 2
|
adaptive_decay: float = shared.args.adaptive_decay
|
||||||
dry_base: float = 1.75
|
dry_multiplier: float = shared.args.dry_multiplier
|
||||||
repetition_penalty: float = 1
|
dry_allowed_length: int = shared.args.dry_allowed_length
|
||||||
encoder_repetition_penalty: float = 1
|
dry_base: float = shared.args.dry_base
|
||||||
no_repeat_ngram_size: int = 0
|
repetition_penalty: float = shared.args.repetition_penalty
|
||||||
repetition_penalty_range: int = 1024
|
encoder_repetition_penalty: float = shared.args.encoder_repetition_penalty
|
||||||
penalty_alpha: float = 0
|
no_repeat_ngram_size: int = shared.args.no_repeat_ngram_size
|
||||||
guidance_scale: float = 1
|
repetition_penalty_range: int = shared.args.repetition_penalty_range
|
||||||
mirostat_mode: int = 0
|
penalty_alpha: float = shared.args.penalty_alpha
|
||||||
mirostat_tau: float = 5
|
guidance_scale: float = shared.args.guidance_scale
|
||||||
mirostat_eta: float = 0.1
|
mirostat_mode: int = shared.args.mirostat_mode
|
||||||
|
mirostat_tau: float = shared.args.mirostat_tau
|
||||||
|
mirostat_eta: float = shared.args.mirostat_eta
|
||||||
prompt_lookup_num_tokens: int = 0
|
prompt_lookup_num_tokens: int = 0
|
||||||
max_tokens_second: int = 0
|
max_tokens_second: int = 0
|
||||||
do_sample: bool = True
|
do_sample: bool = shared.args.do_sample
|
||||||
dynamic_temperature: bool = False
|
dynamic_temperature: bool = shared.args.dynamic_temperature
|
||||||
temperature_last: bool = False
|
temperature_last: bool = shared.args.temperature_last
|
||||||
auto_max_new_tokens: bool = False
|
auto_max_new_tokens: bool = False
|
||||||
ban_eos_token: bool = False
|
ban_eos_token: bool = False
|
||||||
add_bos_token: bool = True
|
add_bos_token: bool = True
|
||||||
enable_thinking: bool = True
|
enable_thinking: bool = shared.args.enable_thinking
|
||||||
reasoning_effort: str = "medium"
|
reasoning_effort: str = shared.args.reasoning_effort
|
||||||
skip_special_tokens: bool = True
|
skip_special_tokens: bool = True
|
||||||
static_cache: bool = False
|
static_cache: bool = False
|
||||||
truncation_length: int = 0
|
truncation_length: int = 0
|
||||||
seed: int = -1
|
seed: int = -1
|
||||||
sampler_priority: List[str] | str | None = Field(default=None, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
|
sampler_priority: List[str] | str | None = Field(default=shared.args.sampler_priority, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
|
||||||
custom_token_bans: str = ""
|
custom_token_bans: str = ""
|
||||||
negative_prompt: str = ''
|
negative_prompt: str = ''
|
||||||
dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
|
dry_sequence_breakers: str = shared.args.dry_sequence_breakers
|
||||||
grammar_string: str = ""
|
grammar_string: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,22 +65,20 @@ class ToolDefinition(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ToolFunction(BaseModel):
|
class ToolFunction(BaseModel):
|
||||||
description: str
|
model_config = ConfigDict(extra='allow')
|
||||||
|
description: Optional[str] = None
|
||||||
name: str
|
name: str
|
||||||
parameters: 'ToolParameters'
|
parameters: Optional['ToolParameters'] = None
|
||||||
|
|
||||||
|
|
||||||
class ToolParameters(BaseModel):
|
class ToolParameters(BaseModel):
|
||||||
properties: Optional[Dict[str, 'ToolProperty']] = None
|
model_config = ConfigDict(extra='allow')
|
||||||
|
properties: Optional[Dict[str, Any]] = None
|
||||||
required: Optional[list[str]] = None
|
required: Optional[list[str]] = None
|
||||||
type: str
|
type: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class ToolProperty(BaseModel):
|
|
||||||
description: Optional[str] = None
|
|
||||||
type: Optional[str] = None # we are faced with definitions like anyOf, e.g. {'type': 'function', 'function': {'name': 'git_create_branch', 'description': 'Creates a new branch from an optional base branch', 'parameters': {'type': 'object', 'properties': {'repo_path': {'title': 'Repo Path', 'type': 'string'}, 'branch_name': {'title': 'Branch Name', 'type': 'string'}, 'base_branch': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None, 'title': 'Base Branch'}}, 'required': ['repo_path', 'branch_name'], 'title': 'GitCreateBranch'}}}
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionCall(BaseModel):
|
class FunctionCall(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
@ -97,23 +99,28 @@ class ToolCall(BaseModel):
|
||||||
function: FunctionCall
|
function: FunctionCall
|
||||||
|
|
||||||
|
|
||||||
|
class StreamOptions(BaseModel):
|
||||||
|
include_usage: bool | None = False
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequestParams(BaseModel):
|
class CompletionRequestParams(BaseModel):
|
||||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
||||||
prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
|
prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
|
||||||
messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.")
|
messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.")
|
||||||
best_of: int | None = Field(default=1, description="Unused parameter.")
|
best_of: int | None = Field(default=1, description="Unused parameter.")
|
||||||
echo: bool | None = False
|
echo: bool | None = False
|
||||||
frequency_penalty: float | None = 0
|
frequency_penalty: float | None = shared.args.frequency_penalty
|
||||||
logit_bias: dict | None = None
|
logit_bias: dict | None = None
|
||||||
logprobs: int | None = None
|
logprobs: int | None = None
|
||||||
max_tokens: int | None = 512
|
max_tokens: int | None = 512
|
||||||
n: int | None = Field(default=1, description="Unused parameter.")
|
n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.")
|
||||||
presence_penalty: float | None = 0
|
presence_penalty: float | None = shared.args.presence_penalty
|
||||||
stop: str | List[str] | None = None
|
stop: str | List[str] | None = None
|
||||||
stream: bool | None = False
|
stream: bool | None = False
|
||||||
|
stream_options: StreamOptions | None = None
|
||||||
suffix: str | None = None
|
suffix: str | None = None
|
||||||
temperature: float | None = 1
|
temperature: float | None = shared.args.temperature
|
||||||
top_p: float | None = 1
|
top_p: float | None = shared.args.top_p
|
||||||
user: str | None = Field(default=None, description="Unused parameter.")
|
user: str | None = Field(default=None, description="Unused parameter.")
|
||||||
|
|
||||||
@model_validator(mode='after')
|
@model_validator(mode='after')
|
||||||
|
|
@ -139,20 +146,31 @@ class CompletionResponse(BaseModel):
|
||||||
class ChatCompletionRequestParams(BaseModel):
|
class ChatCompletionRequestParams(BaseModel):
|
||||||
messages: List[dict]
|
messages: List[dict]
|
||||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
||||||
frequency_penalty: float | None = 0
|
frequency_penalty: float | None = shared.args.frequency_penalty
|
||||||
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
|
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
|
||||||
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
|
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
|
||||||
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
|
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
|
||||||
|
tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
|
||||||
logit_bias: dict | None = None
|
logit_bias: dict | None = None
|
||||||
|
logprobs: bool | None = None
|
||||||
|
top_logprobs: int | None = None
|
||||||
max_tokens: int | None = None
|
max_tokens: int | None = None
|
||||||
|
max_completion_tokens: int | None = None
|
||||||
n: int | None = Field(default=1, description="Unused parameter.")
|
n: int | None = Field(default=1, description="Unused parameter.")
|
||||||
presence_penalty: float | None = 0
|
presence_penalty: float | None = shared.args.presence_penalty
|
||||||
stop: str | List[str] | None = None
|
stop: str | List[str] | None = None
|
||||||
stream: bool | None = False
|
stream: bool | None = False
|
||||||
temperature: float | None = 1
|
stream_options: StreamOptions | None = None
|
||||||
top_p: float | None = 1
|
temperature: float | None = shared.args.temperature
|
||||||
|
top_p: float | None = shared.args.top_p
|
||||||
user: str | None = Field(default=None, description="Unused parameter.")
|
user: str | None = Field(default=None, description="Unused parameter.")
|
||||||
|
|
||||||
|
@model_validator(mode='after')
|
||||||
|
def resolve_max_tokens(self):
|
||||||
|
if self.max_tokens is None and self.max_completion_tokens is not None:
|
||||||
|
self.max_tokens = self.max_completion_tokens
|
||||||
|
return self
|
||||||
|
|
||||||
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
|
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
|
||||||
|
|
||||||
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
|
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
|
||||||
|
|
@ -226,11 +244,11 @@ class LogitsRequestParams(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
use_samplers: bool = False
|
use_samplers: bool = False
|
||||||
top_logits: int | None = 50
|
top_logits: int | None = 50
|
||||||
frequency_penalty: float | None = 0
|
frequency_penalty: float | None = shared.args.frequency_penalty
|
||||||
max_tokens: int | None = 512
|
max_tokens: int | None = 512
|
||||||
presence_penalty: float | None = 0
|
presence_penalty: float | None = shared.args.presence_penalty
|
||||||
temperature: float | None = 1
|
temperature: float | None = shared.args.temperature
|
||||||
top_p: float | None = 1
|
top_p: float | None = shared.args.top_p
|
||||||
|
|
||||||
|
|
||||||
class LogitsRequest(GenerationOptions, LogitsRequestParams):
|
class LogitsRequest(GenerationOptions, LogitsRequestParams):
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,5 @@
|
||||||
import base64
|
import base64
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import re
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
@ -55,94 +52,3 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
|
|
||||||
raise Exception('Could not start cloudflared.')
|
raise Exception('Could not start cloudflared.')
|
||||||
|
|
||||||
|
|
||||||
def getToolCallId() -> str:
|
|
||||||
letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
||||||
b = [random.choice(letter_bytes) for _ in range(8)]
|
|
||||||
return "call_" + "".join(b).lower()
|
|
||||||
|
|
||||||
|
|
||||||
def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
|
|
||||||
# check if property 'function' exists and is a dictionary, otherwise adapt dict
|
|
||||||
if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
|
|
||||||
candidate_dict = {"type": "function", "function": candidate_dict}
|
|
||||||
if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
|
|
||||||
candidate_dict['name'] = candidate_dict['function']
|
|
||||||
del candidate_dict['function']
|
|
||||||
candidate_dict = {"type": "function", "function": candidate_dict}
|
|
||||||
if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
|
|
||||||
# check if 'name' exists within 'function' and is part of known tools
|
|
||||||
if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
|
|
||||||
candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value
|
|
||||||
# map property 'parameters' used by some older models to 'arguments'
|
|
||||||
if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
|
|
||||||
candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
|
|
||||||
del candidate_dict["function"]["parameters"]
|
|
||||||
return candidate_dict
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def parseToolCall(answer: str, tool_names: list[str]):
|
|
||||||
matches = []
|
|
||||||
|
|
||||||
# abort on very short answers to save computation cycles
|
|
||||||
if len(answer) < 10:
|
|
||||||
return matches
|
|
||||||
|
|
||||||
# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
|
|
||||||
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
|
|
||||||
|
|
||||||
for pattern in patterns:
|
|
||||||
for match in re.finditer(pattern, answer, re.DOTALL):
|
|
||||||
# print(match.group(2))
|
|
||||||
if match.group(2) is None:
|
|
||||||
continue
|
|
||||||
# remove backtick wraps if present
|
|
||||||
candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
|
|
||||||
candidate = re.sub(r"```$", "", candidate.strip())
|
|
||||||
# unwrap inner tags
|
|
||||||
candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
|
|
||||||
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
|
|
||||||
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
|
|
||||||
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
|
|
||||||
if not candidate.strip().startswith("["):
|
|
||||||
candidate = "[" + candidate + "]"
|
|
||||||
|
|
||||||
candidates = []
|
|
||||||
try:
|
|
||||||
# parse the candidate JSON into a dictionary
|
|
||||||
candidates = json.loads(candidate)
|
|
||||||
if not isinstance(candidates, list):
|
|
||||||
candidates = [candidates]
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# Ignore invalid JSON silently
|
|
||||||
continue
|
|
||||||
|
|
||||||
for candidate_dict in candidates:
|
|
||||||
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
|
|
||||||
if checked_candidate is not None:
|
|
||||||
matches.append(checked_candidate)
|
|
||||||
|
|
||||||
# last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
|
|
||||||
if len(matches) == 0:
|
|
||||||
try:
|
|
||||||
candidate = answer
|
|
||||||
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
|
|
||||||
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
|
|
||||||
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
|
|
||||||
if not candidate.strip().startswith("["):
|
|
||||||
candidate = "[" + candidate + "]"
|
|
||||||
# parse the candidate JSON into a dictionary
|
|
||||||
candidates = json.loads(candidate)
|
|
||||||
if not isinstance(candidates, list):
|
|
||||||
candidates = [candidates]
|
|
||||||
for candidate_dict in candidates:
|
|
||||||
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
|
|
||||||
if checked_candidate is not None:
|
|
||||||
matches.append(checked_candidate)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# Ignore invalid JSON silently
|
|
||||||
pass
|
|
||||||
|
|
||||||
return matches
|
|
||||||
|
|
|
||||||
|
|
@ -264,7 +264,7 @@ def SD_api_address_update(address):
|
||||||
response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
|
response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
# r = response.json()
|
# r = response.json()
|
||||||
except:
|
except Exception:
|
||||||
msg = "❌ No SD API endpoint on:"
|
msg = "❌ No SD API endpoint on:"
|
||||||
|
|
||||||
return gr.Textbox.update(label=msg)
|
return gr.Textbox.update(label=msg)
|
||||||
|
|
@ -284,7 +284,7 @@ def get_checkpoints():
|
||||||
options_json = options.json()
|
options_json = options.json()
|
||||||
params['sd_checkpoint'] = options_json['sd_model_checkpoint']
|
params['sd_checkpoint'] = options_json['sd_model_checkpoint']
|
||||||
params['checkpoint_list'] = [result["title"] for result in models.json()]
|
params['checkpoint_list'] = [result["title"] for result in models.json()]
|
||||||
except:
|
except Exception:
|
||||||
params['sd_checkpoint'] = ""
|
params['sd_checkpoint'] = ""
|
||||||
params['checkpoint_list'] = []
|
params['checkpoint_list'] = []
|
||||||
|
|
||||||
|
|
@ -298,7 +298,7 @@ def load_checkpoint(checkpoint):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload)
|
requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -307,7 +307,7 @@ def get_samplers():
|
||||||
response = requests.get(url=f'{params["address"]}/sdapi/v1/samplers')
|
response = requests.get(url=f'{params["address"]}/sdapi/v1/samplers')
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
samplers = [x["name"] for x in response.json()]
|
samplers = [x["name"] for x in response.json()]
|
||||||
except:
|
except Exception:
|
||||||
samplers = []
|
samplers = []
|
||||||
|
|
||||||
return samplers
|
return samplers
|
||||||
|
|
|
||||||
|
|
@ -2,5 +2,5 @@ beautifulsoup4==4.12.2
|
||||||
chromadb==0.4.24
|
chromadb==0.4.24
|
||||||
pandas==2.0.3
|
pandas==2.0.3
|
||||||
posthog==2.4.2
|
posthog==2.4.2
|
||||||
sentence_transformers==2.2.2
|
sentence_transformers==3.3.1
|
||||||
lxml
|
lxml
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,11 @@ function copyToClipboard(element) {
|
||||||
const rawText = messageElement.getAttribute("data-raw");
|
const rawText = messageElement.getAttribute("data-raw");
|
||||||
if (!rawText) return;
|
if (!rawText) return;
|
||||||
|
|
||||||
navigator.clipboard.writeText(rawText).then(function() {
|
const copyPromise = navigator.clipboard && window.isSecureContext
|
||||||
|
? navigator.clipboard.writeText(rawText)
|
||||||
|
: fallbackCopyToClipboard(rawText);
|
||||||
|
|
||||||
|
copyPromise.then(function() {
|
||||||
const originalSvg = element.innerHTML;
|
const originalSvg = element.innerHTML;
|
||||||
element.innerHTML = "<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"20\" height=\"20\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\" class=\"text-green-500 dark:text-green-400\"><path d=\"M5 12l5 5l10 -10\"></path></svg>";
|
element.innerHTML = "<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"20\" height=\"20\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"currentColor\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\" class=\"text-green-500 dark:text-green-400\"><path d=\"M5 12l5 5l10 -10\"></path></svg>";
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
|
|
@ -22,6 +26,27 @@ function copyToClipboard(element) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function fallbackCopyToClipboard(text) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const textArea = document.createElement("textarea");
|
||||||
|
textArea.value = text;
|
||||||
|
textArea.style.position = "fixed";
|
||||||
|
textArea.style.left = "-9999px";
|
||||||
|
textArea.style.top = "-9999px";
|
||||||
|
document.body.appendChild(textArea);
|
||||||
|
textArea.focus();
|
||||||
|
textArea.select();
|
||||||
|
try {
|
||||||
|
const successful = document.execCommand("copy");
|
||||||
|
document.body.removeChild(textArea);
|
||||||
|
successful ? resolve() : reject();
|
||||||
|
} catch (err) {
|
||||||
|
document.body.removeChild(textArea);
|
||||||
|
reject(err);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
function branchHere(element) {
|
function branchHere(element) {
|
||||||
if (!element) return;
|
if (!element) return;
|
||||||
|
|
||||||
|
|
@ -244,7 +269,49 @@ function removeLastClick() {
|
||||||
document.getElementById("Remove-last").click();
|
document.getElementById("Remove-last").click();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function autoScrollToBottom() {
|
||||||
|
if (!window.isScrolled) {
|
||||||
|
const chatParent = document.getElementById("chat")?.parentNode?.parentNode?.parentNode;
|
||||||
|
if (chatParent) {
|
||||||
|
const maxScroll = chatParent.scrollHeight - chatParent.clientHeight;
|
||||||
|
if (maxScroll > 0 && chatParent.scrollTop < maxScroll - 1) {
|
||||||
|
chatParent.scrollTop = maxScroll;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateInstructPadding() {
|
||||||
|
const chatElement = document.getElementById("chat");
|
||||||
|
if (chatElement && chatElement.getAttribute("data-mode") === "instruct") {
|
||||||
|
const messagesContainer = chatElement.querySelector(".messages");
|
||||||
|
const lastChild = messagesContainer?.lastElementChild;
|
||||||
|
const prevSibling = lastChild?.previousElementSibling;
|
||||||
|
if (lastChild && prevSibling && chatElement.offsetHeight > 0) {
|
||||||
|
let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight);
|
||||||
|
if (window.innerWidth <= 924) {
|
||||||
|
bufferHeight = Math.max(0, bufferHeight - 32);
|
||||||
|
}
|
||||||
|
messagesContainer.style.paddingBottom = `${bufferHeight}px`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let pendingMorphdomData = null;
|
||||||
|
let morphdomRafId = null;
|
||||||
|
|
||||||
function handleMorphdomUpdate(data) {
|
function handleMorphdomUpdate(data) {
|
||||||
|
pendingMorphdomData = data;
|
||||||
|
if (!morphdomRafId) {
|
||||||
|
morphdomRafId = requestAnimationFrame(() => {
|
||||||
|
morphdomRafId = null;
|
||||||
|
applyMorphdomUpdate(pendingMorphdomData);
|
||||||
|
pendingMorphdomData = null;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyMorphdomUpdate(data) {
|
||||||
// Determine target element and use it as query scope
|
// Determine target element and use it as query scope
|
||||||
var target_element, target_html;
|
var target_element, target_html;
|
||||||
if (data.last_message_only) {
|
if (data.last_message_only) {
|
||||||
|
|
@ -258,28 +325,22 @@ function handleMorphdomUpdate(data) {
|
||||||
|
|
||||||
const queryScope = target_element;
|
const queryScope = target_element;
|
||||||
|
|
||||||
// Track open blocks
|
// Track open blocks and store their scroll positions
|
||||||
const openBlocks = new Set();
|
const openBlocks = new Set();
|
||||||
|
const scrollPositions = {};
|
||||||
queryScope.querySelectorAll(".thinking-block").forEach(block => {
|
queryScope.querySelectorAll(".thinking-block").forEach(block => {
|
||||||
const blockId = block.getAttribute("data-block-id");
|
const blockId = block.getAttribute("data-block-id");
|
||||||
// If block exists and is open, add to open set
|
|
||||||
if (blockId && block.hasAttribute("open")) {
|
if (blockId && block.hasAttribute("open")) {
|
||||||
openBlocks.add(blockId);
|
openBlocks.add(blockId);
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Store scroll positions for any open blocks
|
|
||||||
const scrollPositions = {};
|
|
||||||
queryScope.querySelectorAll(".thinking-block[open]").forEach(block => {
|
|
||||||
const content = block.querySelector(".thinking-content");
|
const content = block.querySelector(".thinking-content");
|
||||||
const blockId = block.getAttribute("data-block-id");
|
if (content) {
|
||||||
if (content && blockId) {
|
|
||||||
const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
|
const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
|
||||||
scrollPositions[blockId] = {
|
scrollPositions[blockId] = {
|
||||||
position: content.scrollTop,
|
position: content.scrollTop,
|
||||||
isAtBottom: isAtBottom
|
isAtBottom: isAtBottom
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
morphdom(
|
morphdom(
|
||||||
|
|
@ -288,8 +349,8 @@ function handleMorphdomUpdate(data) {
|
||||||
{
|
{
|
||||||
onBeforeElUpdated: function(fromEl, toEl) {
|
onBeforeElUpdated: function(fromEl, toEl) {
|
||||||
// Preserve code highlighting
|
// Preserve code highlighting
|
||||||
if (fromEl.tagName === "PRE" && fromEl.querySelector("code[data-highlighted]")) {
|
if (fromEl.tagName === "PRE") {
|
||||||
const fromCode = fromEl.querySelector("code");
|
const fromCode = fromEl.querySelector("code[data-highlighted]");
|
||||||
const toCode = toEl.querySelector("code");
|
const toCode = toEl.querySelector("code");
|
||||||
|
|
||||||
if (fromCode && toCode && fromCode.textContent === toCode.textContent) {
|
if (fromCode && toCode && fromCode.textContent === toCode.textContent) {
|
||||||
|
|
@ -334,10 +395,23 @@ function handleMorphdomUpdate(data) {
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Syntax highlighting and LaTeX
|
||||||
|
if (window.doSyntaxHighlighting) {
|
||||||
|
window.doSyntaxHighlighting();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-scroll runs both before and after padding update.
|
||||||
|
// Before: so content growth isn't hidden by padding absorption.
|
||||||
|
// After: so padding-added space is also scrolled into view.
|
||||||
|
autoScrollToBottom();
|
||||||
|
updateInstructPadding();
|
||||||
|
autoScrollToBottom();
|
||||||
|
|
||||||
// Add toggle listeners for new blocks
|
// Add toggle listeners for new blocks
|
||||||
queryScope.querySelectorAll(".thinking-block").forEach(block => {
|
queryScope.querySelectorAll(".thinking-block").forEach(block => {
|
||||||
if (!block._hasToggleListener) {
|
if (!block._hasToggleListener) {
|
||||||
block.addEventListener("toggle", function(e) {
|
block.addEventListener("toggle", function(e) {
|
||||||
|
const wasScrolled = window.isScrolled;
|
||||||
if (this.open) {
|
if (this.open) {
|
||||||
const content = this.querySelector(".thinking-content");
|
const content = this.querySelector(".thinking-content");
|
||||||
if (content) {
|
if (content) {
|
||||||
|
|
@ -346,44 +420,14 @@ function handleMorphdomUpdate(data) {
|
||||||
}, 0);
|
}, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
autoScrollToBottom();
|
||||||
|
updateInstructPadding();
|
||||||
|
autoScrollToBottom();
|
||||||
|
// Restore scroll state so the browser's layout adjustment
|
||||||
|
// from the toggle doesn't disable auto-scroll
|
||||||
|
window.isScrolled = wasScrolled;
|
||||||
});
|
});
|
||||||
block._hasToggleListener = true;
|
block._hasToggleListener = true;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for Gradio to finish setting its styles, then force dark theme
|
|
||||||
const observer = new MutationObserver((mutations) => {
|
|
||||||
mutations.forEach((mutation) => {
|
|
||||||
if (mutation.type === "attributes" &&
|
|
||||||
mutation.target.tagName === "GRADIO-APP" &&
|
|
||||||
mutation.attributeName === "style") {
|
|
||||||
|
|
||||||
// Gradio just set its styles, now force dark theme
|
|
||||||
document.body.classList.add("dark");
|
|
||||||
observer.disconnect();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
// Start observing
|
|
||||||
observer.observe(document.documentElement, {
|
|
||||||
attributes: true,
|
|
||||||
subtree: true,
|
|
||||||
attributeFilter: ["style"]
|
|
||||||
});
|
|
||||||
|
|
||||||
//------------------------------------------------
|
|
||||||
// Suppress "Attempted to select a non-interactive or hidden tab" warning
|
|
||||||
//------------------------------------------------
|
|
||||||
(function() {
|
|
||||||
const originalWarn = console.warn;
|
|
||||||
|
|
||||||
console.warn = function(...args) {
|
|
||||||
if (args[0] && typeof args[0] === "string" && args[0].includes("Attempted to select a non-interactive or hidden tab")) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
originalWarn.apply(console, args);
|
|
||||||
};
|
|
||||||
})();
|
|
||||||
|
|
|
||||||
85
js/highlightjs/highlightjs-copy.min.js
vendored
85
js/highlightjs/highlightjs-copy.min.js
vendored
|
|
@ -1 +1,84 @@
|
||||||
class CopyButtonPlugin{constructor(options={}){self.hook=options.hook;self.callback=options.callback;self.lang=options.lang||document.documentElement.lang||"en"}"after:highlightElement"({el,text}){let button=Object.assign(document.createElement("button"),{innerHTML:locales[lang]?.[0]||"Copy",className:"hljs-copy-button"});button.dataset.copied=false;el.parentElement.classList.add("hljs-copy-wrapper");el.parentElement.appendChild(button);el.parentElement.style.setProperty("--hljs-theme-background",window.getComputedStyle(el).backgroundColor);button.onclick=function(){if(!navigator.clipboard)return;let newText=text;if(hook&&typeof hook==="function"){newText=hook(text,el)||text}navigator.clipboard.writeText(newText).then(function(){button.innerHTML=locales[lang]?.[1]||"Copied!";button.dataset.copied=true;let alert=Object.assign(document.createElement("div"),{role:"status",className:"hljs-copy-alert",innerHTML:locales[lang]?.[2]||"Copied to clipboard"});el.parentElement.appendChild(alert);setTimeout(()=>{button.innerHTML=locales[lang]?.[0]||"Copy";button.dataset.copied=false;el.parentElement.removeChild(alert);alert=null},2e3)}).then(function(){if(typeof callback==="function")return callback(newText,el)})}}}if(typeof module!="undefined"){module.exports=CopyButtonPlugin}const locales={en:["Copy","Copied!","Copied to clipboard"],es:["Copiar","¡Copiado!","Copiado al portapapeles"],fr:["Copier","Copié !","Copié dans le presse-papier"],de:["Kopieren","Kopiert!","In die Zwischenablage kopiert"],ja:["コピー","コピーしました!","クリップボードにコピーしました"],ko:["복사","복사됨!","클립보드에 복사됨"],ru:["Копировать","Скопировано!","Скопировано в буфер обмена"],zh:["复制","已复制!","已复制到剪贴板"],"zh-tw":["複製","已複製!","已複製到剪貼簿"]};
|
function fallbackCopyToClipboard(text) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const textArea = document.createElement("textarea");
|
||||||
|
textArea.value = text;
|
||||||
|
textArea.style.position = "fixed";
|
||||||
|
textArea.style.left = "-9999px";
|
||||||
|
textArea.style.top = "-9999px";
|
||||||
|
document.body.appendChild(textArea);
|
||||||
|
textArea.focus();
|
||||||
|
textArea.select();
|
||||||
|
try {
|
||||||
|
const successful = document.execCommand("copy");
|
||||||
|
document.body.removeChild(textArea);
|
||||||
|
successful ? resolve() : reject();
|
||||||
|
} catch (err) {
|
||||||
|
document.body.removeChild(textArea);
|
||||||
|
reject(err);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
class CopyButtonPlugin {
|
||||||
|
constructor(options = {}) {
|
||||||
|
self.hook = options.hook;
|
||||||
|
self.callback = options.callback;
|
||||||
|
self.lang = options.lang || document.documentElement.lang || "en";
|
||||||
|
}
|
||||||
|
"after:highlightElement"({ el, text }) {
|
||||||
|
let button = Object.assign(document.createElement("button"), {
|
||||||
|
innerHTML: locales[lang]?.[0] || "Copy",
|
||||||
|
className: "hljs-copy-button",
|
||||||
|
});
|
||||||
|
button.dataset.copied = false;
|
||||||
|
el.parentElement.classList.add("hljs-copy-wrapper");
|
||||||
|
el.parentElement.appendChild(button);
|
||||||
|
el.parentElement.style.setProperty(
|
||||||
|
"--hljs-theme-background",
|
||||||
|
window.getComputedStyle(el).backgroundColor,
|
||||||
|
);
|
||||||
|
button.onclick = function () {
|
||||||
|
let newText = text;
|
||||||
|
if (hook && typeof hook === "function") {
|
||||||
|
newText = hook(text, el) || text;
|
||||||
|
}
|
||||||
|
const copyPromise =
|
||||||
|
navigator.clipboard && window.isSecureContext
|
||||||
|
? navigator.clipboard.writeText(newText)
|
||||||
|
: fallbackCopyToClipboard(newText);
|
||||||
|
copyPromise.then(function () {
|
||||||
|
button.innerHTML = locales[lang]?.[1] || "Copied!";
|
||||||
|
button.dataset.copied = true;
|
||||||
|
let alert = Object.assign(document.createElement("div"), {
|
||||||
|
role: "status",
|
||||||
|
className: "hljs-copy-alert",
|
||||||
|
innerHTML: locales[lang]?.[2] || "Copied to clipboard",
|
||||||
|
});
|
||||||
|
el.parentElement.appendChild(alert);
|
||||||
|
setTimeout(() => {
|
||||||
|
button.innerHTML = locales[lang]?.[0] || "Copy";
|
||||||
|
button.dataset.copied = false;
|
||||||
|
el.parentElement.removeChild(alert);
|
||||||
|
alert = null;
|
||||||
|
}, 2e3);
|
||||||
|
})
|
||||||
|
.then(function () {
|
||||||
|
if (typeof callback === "function") return callback(newText, el);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (typeof module != "undefined") {
|
||||||
|
module.exports = CopyButtonPlugin;
|
||||||
|
}
|
||||||
|
const locales = {
|
||||||
|
en: ["Copy", "Copied!", "Copied to clipboard"],
|
||||||
|
es: ["Copiar", "¡Copiado!", "Copiado al portapapeles"],
|
||||||
|
fr: ["Copier", "Copié !", "Copié dans le presse-papier"],
|
||||||
|
de: ["Kopieren", "Kopiert!", "In die Zwischenablage kopiert"],
|
||||||
|
ja: ["コピー", "コピーしました!", "クリップボードにコピーしました"],
|
||||||
|
ko: ["복사", "복사됨!", "클립보드에 복사됨"],
|
||||||
|
ru: ["Копировать", "Скопировано!", "Скопировано в буфер обмена"],
|
||||||
|
zh: ["复制", "已复制!", "已复制到剪贴板"],
|
||||||
|
"zh-tw": ["複製", "已複製!", "已複製到剪貼簿"],
|
||||||
|
};
|
||||||
|
|
|
||||||
184
js/katex/auto-render.js
Normal file
184
js/katex/auto-render.js
Normal file
|
|
@ -0,0 +1,184 @@
|
||||||
|
! function(e, t) {
|
||||||
|
"object" == typeof exports && "object" == typeof module ? module.exports = t(require("katex")) : "function" == typeof define && define.amd ? define(["katex"], t) : "object" == typeof exports ? exports.renderMathInElement = t(require("katex")) : e.renderMathInElement = t(e.katex)
|
||||||
|
}("undefined" != typeof self ? self : this, (function(e) {
|
||||||
|
return function() {
|
||||||
|
"use strict";
|
||||||
|
var t = {
|
||||||
|
771: function(t) {
|
||||||
|
t.exports = e
|
||||||
|
}
|
||||||
|
},
|
||||||
|
n = {};
|
||||||
|
|
||||||
|
function r(e) {
|
||||||
|
var o = n[e];
|
||||||
|
if (void 0 !== o) return o.exports;
|
||||||
|
var i = n[e] = {
|
||||||
|
exports: {}
|
||||||
|
};
|
||||||
|
return t[e](i, i.exports, r), i.exports
|
||||||
|
}
|
||||||
|
r.n = function(e) {
|
||||||
|
var t = e && e.__esModule ? function() {
|
||||||
|
return e.default
|
||||||
|
} : function() {
|
||||||
|
return e
|
||||||
|
};
|
||||||
|
return r.d(t, {
|
||||||
|
a: t
|
||||||
|
}), t
|
||||||
|
}, r.d = function(e, t) {
|
||||||
|
for (var n in t) r.o(t, n) && !r.o(e, n) && Object.defineProperty(e, n, {
|
||||||
|
enumerable: !0,
|
||||||
|
get: t[n]
|
||||||
|
})
|
||||||
|
}, r.o = function(e, t) {
|
||||||
|
return Object.prototype.hasOwnProperty.call(e, t)
|
||||||
|
};
|
||||||
|
var o = {};
|
||||||
|
return function() {
|
||||||
|
r.d(o, {
|
||||||
|
default: function() {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
});
|
||||||
|
var e = r(771),
|
||||||
|
t = r.n(e);
|
||||||
|
const n = function(e, t, n) {
|
||||||
|
let r = n,
|
||||||
|
o = 0;
|
||||||
|
const i = e.length;
|
||||||
|
for (; r < t.length;) {
|
||||||
|
const n = t[r];
|
||||||
|
if (o <= 0 && t.slice(r, r + i) === e) return r;
|
||||||
|
"\\" === n ? r++ : "{" === n ? o++ : "}" === n && o--, r++
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
},
|
||||||
|
i = /^\\begin{/;
|
||||||
|
var a = function(e, t) {
|
||||||
|
let r;
|
||||||
|
const o = [],
|
||||||
|
a = new RegExp("(" + t.map((e => e.left.replace(/[-/\\^$*+?.()|[\]{}]/g, "\\$&"))).join("|") + ")");
|
||||||
|
for (; r = e.search(a), -1 !== r;) {
|
||||||
|
const charAfterOpen = e[r + 1];
|
||||||
|
if (e[r] == "$" && charAfterOpen != "$") {
|
||||||
|
const closeDollarIndex = e.indexOf('$', r + 1);
|
||||||
|
if (closeDollarIndex != -1) {
|
||||||
|
const charBeforeOpen = r > 0 ? e[r - 1] : '';
|
||||||
|
const charBeforeClose = r + 1 < closeDollarIndex ? e[closeDollarIndex - 1] : '';
|
||||||
|
const charBeforeBeforeClose = r + 1 < closeDollarIndex ? e[closeDollarIndex - 2] : '';
|
||||||
|
const charAfterClose = closeDollarIndex + 1 < e.length ? e[closeDollarIndex + 1] : '';
|
||||||
|
if ((/[A-Za-z0-9_$-]/.test(charBeforeOpen)) || ((' ' == charBeforeClose) ||
|
||||||
|
/[0-9]/.test(charAfterOpen) &&
|
||||||
|
(/[A-Za-z0-9]/.test(charAfterClose)
|
||||||
|
|| '-' == charBeforeClose))) {
|
||||||
|
o.push({
|
||||||
|
type: "text",
|
||||||
|
data: e.slice(0, r + 1),
|
||||||
|
});
|
||||||
|
e = e.slice(r + 1); // now text starts after delimiter
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r > 0 && (o.push({
|
||||||
|
type: "text",
|
||||||
|
data: e.slice(0, r)
|
||||||
|
}), e = e.slice(r));
|
||||||
|
const a = t.findIndex((t => e.startsWith(t.left)));
|
||||||
|
if (r = n(t[a].right, e, t[a].left.length), -1 === r) break;
|
||||||
|
const l = e.slice(0, r + t[a].right.length),
|
||||||
|
s = i.test(l) ? l : e.slice(t[a].left.length, r);
|
||||||
|
o.push({
|
||||||
|
type: "math",
|
||||||
|
data: s,
|
||||||
|
rawData: l,
|
||||||
|
display: t[a].display
|
||||||
|
}), e = e.slice(r + t[a].right.length)
|
||||||
|
}
|
||||||
|
return "" !== e && o.push({
|
||||||
|
type: "text",
|
||||||
|
data: e
|
||||||
|
}), o
|
||||||
|
};
|
||||||
|
const l = function(e, n) {
|
||||||
|
const r = a(e, n.delimiters);
|
||||||
|
if (1 === r.length && "text" === r[0].type) return null;
|
||||||
|
const o = document.createDocumentFragment();
|
||||||
|
for (let e = 0; e < r.length; e++)
|
||||||
|
if ("text" === r[e].type) o.appendChild(document.createTextNode(r[e].data));
|
||||||
|
else {
|
||||||
|
const i = document.createElement("span");
|
||||||
|
let a = r[e].data;
|
||||||
|
n.displayMode = r[e].display;
|
||||||
|
try {
|
||||||
|
n.preProcess && (a = n.preProcess(a)), t().render(a, i, n)
|
||||||
|
} catch (i) {
|
||||||
|
if (!(i instanceof t().ParseError)) throw i;
|
||||||
|
n.errorCallback("KaTeX auto-render: Failed to parse `" + r[e].data + "` with ", i), o.appendChild(document.createTextNode(r[e].rawData));
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
o.appendChild(i)
|
||||||
|
}
|
||||||
|
return o
|
||||||
|
},
|
||||||
|
s = function(e, t) {
|
||||||
|
for (let n = 0; n < e.childNodes.length; n++) {
|
||||||
|
const r = e.childNodes[n];
|
||||||
|
if (3 === r.nodeType) {
|
||||||
|
let o = r.textContent,
|
||||||
|
i = r.nextSibling,
|
||||||
|
a = 0;
|
||||||
|
for (; i && i.nodeType === Node.TEXT_NODE;) o += i.textContent, i = i.nextSibling, a++;
|
||||||
|
const s = l(o, t);
|
||||||
|
if (s) {
|
||||||
|
for (let e = 0; e < a; e++) r.nextSibling.remove();
|
||||||
|
n += s.childNodes.length - 1, e.replaceChild(s, r)
|
||||||
|
} else n += a
|
||||||
|
} else if (1 === r.nodeType) {
|
||||||
|
const e = " " + r.className + " "; - 1 === t.ignoredTags.indexOf(r.nodeName.toLowerCase()) && t.ignoredClasses.every((t => -1 === e.indexOf(" " + t + " "))) && s(r, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
var d = function(e, t) {
|
||||||
|
if (!e) throw new Error("No element provided to render");
|
||||||
|
const n = {};
|
||||||
|
for (const e in t) t.hasOwnProperty(e) && (n[e] = t[e]);
|
||||||
|
n.delimiters = n.delimiters || [{
|
||||||
|
left: "$$",
|
||||||
|
right: "$$",
|
||||||
|
display: !0
|
||||||
|
}, {
|
||||||
|
left: "\\(",
|
||||||
|
right: "\\)",
|
||||||
|
display: !1
|
||||||
|
}, {
|
||||||
|
left: "\\begin{equation}",
|
||||||
|
right: "\\end{equation}",
|
||||||
|
display: !0
|
||||||
|
}, {
|
||||||
|
left: "\\begin{align}",
|
||||||
|
right: "\\end{align}",
|
||||||
|
display: !0
|
||||||
|
}, {
|
||||||
|
left: "\\begin{alignat}",
|
||||||
|
right: "\\end{alignat}",
|
||||||
|
display: !0
|
||||||
|
}, {
|
||||||
|
left: "\\begin{gather}",
|
||||||
|
right: "\\end{gather}",
|
||||||
|
display: !0
|
||||||
|
}, {
|
||||||
|
left: "\\begin{CD}",
|
||||||
|
right: "\\end{CD}",
|
||||||
|
display: !0
|
||||||
|
}, {
|
||||||
|
left: "\\[",
|
||||||
|
right: "\\]",
|
||||||
|
display: !0
|
||||||
|
}], n.ignoredTags = n.ignoredTags || ["script", "noscript", "style", "textarea", "pre", "code", "option"], n.ignoredClasses = n.ignoredClasses || [], n.errorCallback = n.errorCallback || console.error, n.macros = n.macros || {}, s(e, n)
|
||||||
|
}
|
||||||
|
}(), o = o.default
|
||||||
|
}()
|
||||||
|
}));
|
||||||
1
js/katex/auto-render.min.js
vendored
1
js/katex/auto-render.min.js
vendored
|
|
@ -1 +0,0 @@
|
||||||
!function(e,t){"object"==typeof exports&&"object"==typeof module?module.exports=t(require("katex")):"function"==typeof define&&define.amd?define(["katex"],t):"object"==typeof exports?exports.renderMathInElement=t(require("katex")):e.renderMathInElement=t(e.katex)}("undefined"!=typeof self?self:this,(function(e){return function(){"use strict";var t={771:function(t){t.exports=e}},n={};function r(e){var o=n[e];if(void 0!==o)return o.exports;var i=n[e]={exports:{}};return t[e](i,i.exports,r),i.exports}r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,{a:t}),t},r.d=function(e,t){for(var n in t)r.o(t,n)&&!r.o(e,n)&&Object.defineProperty(e,n,{enumerable:!0,get:t[n]})},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)};var o={};return function(){r.d(o,{default:function(){return d}});var e=r(771),t=r.n(e);const n=function(e,t,n){let r=n,o=0;const i=e.length;for(;r<t.length;){const n=t[r];if(o<=0&&t.slice(r,r+i)===e)return r;"\\"===n?r++:"{"===n?o++:"}"===n&&o--,r++}return-1},i=/^\\begin{/;var a=function(e,t){let r;const o=[],a=new RegExp("("+t.map((e=>e.left.replace(/[-/\\^$*+?.()|[\]{}]/g,"\\$&"))).join("|")+")");for(;r=e.search(a),-1!==r;){r>0&&(o.push({type:"text",data:e.slice(0,r)}),e=e.slice(r));const a=t.findIndex((t=>e.startsWith(t.left)));if(r=n(t[a].right,e,t[a].left.length),-1===r)break;const l=e.slice(0,r+t[a].right.length),s=i.test(l)?l:e.slice(t[a].left.length,r);o.push({type:"math",data:s,rawData:l,display:t[a].display}),e=e.slice(r+t[a].right.length)}return""!==e&&o.push({type:"text",data:e}),o};const l=function(e,n){const r=a(e,n.delimiters);if(1===r.length&&"text"===r[0].type)return null;const o=document.createDocumentFragment();for(let e=0;e<r.length;e++)if("text"===r[e].type)o.appendChild(document.createTextNode(r[e].data));else{const i=document.createElement("span");let a=r[e].data;n.displayMode=r[e].display;try{n.preProcess&&(a=n.preProcess(a)),t().render(a,i,n)}catch(i){if(!(i instanceof t().ParseError))throw i;n.errorCallback("KaTeX auto-render: Failed to parse `"+r[e].data+"` with ",i),o.appendChild(document.createTextNode(r[e].rawData));continue}o.appendChild(i)}return o},s=function(e,t){for(let n=0;n<e.childNodes.length;n++){const r=e.childNodes[n];if(3===r.nodeType){let o=r.textContent,i=r.nextSibling,a=0;for(;i&&i.nodeType===Node.TEXT_NODE;)o+=i.textContent,i=i.nextSibling,a++;const s=l(o,t);if(s){for(let e=0;e<a;e++)r.nextSibling.remove();n+=s.childNodes.length-1,e.replaceChild(s,r)}else n+=a}else if(1===r.nodeType){const e=" "+r.className+" ";-1===t.ignoredTags.indexOf(r.nodeName.toLowerCase())&&t.ignoredClasses.every((t=>-1===e.indexOf(" "+t+" ")))&&s(r,t)}}};var d=function(e,t){if(!e)throw new Error("No element provided to render");const n={};for(const e in t)t.hasOwnProperty(e)&&(n[e]=t[e]);n.delimiters=n.delimiters||[{left:"$$",right:"$$",display:!0},{left:"\\(",right:"\\)",display:!1},{left:"\\begin{equation}",right:"\\end{equation}",display:!0},{left:"\\begin{align}",right:"\\end{align}",display:!0},{left:"\\begin{alignat}",right:"\\end{alignat}",display:!0},{left:"\\begin{gather}",right:"\\end{gather}",display:!0},{left:"\\begin{CD}",right:"\\end{CD}",display:!0},{left:"\\[",right:"\\]",display:!0}],n.ignoredTags=n.ignoredTags||["script","noscript","style","textarea","pre","code","option"],n.ignoredClasses=n.ignoredClasses||[],n.errorCallback=n.errorCallback||console.error,n.macros=n.macros||{},s(e,n)}}(),o=o.default}()}));
|
|
||||||
150
js/main.js
150
js/main.js
|
|
@ -2,6 +2,12 @@
|
||||||
// Main
|
// Main
|
||||||
// ------------------------------------------------
|
// ------------------------------------------------
|
||||||
|
|
||||||
|
// Sync highlight.js theme with the actual Gradio theme
|
||||||
|
var defined_hljs_css = document.body.classList.contains("dark") ? "file/css/highlightjs/github-dark.min.css" : "file/css/highlightjs/github.min.css";
|
||||||
|
if (document.getElementById("highlight-css").getAttribute("href") !== defined_hljs_css) {
|
||||||
|
document.getElementById("highlight-css").setAttribute("href", defined_hljs_css);
|
||||||
|
}
|
||||||
|
|
||||||
let main_parent = document.getElementById("chat-tab").parentNode;
|
let main_parent = document.getElementById("chat-tab").parentNode;
|
||||||
let extensions = document.getElementById("extensions");
|
let extensions = document.getElementById("extensions");
|
||||||
|
|
||||||
|
|
@ -145,10 +151,13 @@ targetElement.classList.add("pretty_scrollbar");
|
||||||
targetElement.classList.add("chat-parent");
|
targetElement.classList.add("chat-parent");
|
||||||
window.isScrolled = false;
|
window.isScrolled = false;
|
||||||
let scrollTimeout;
|
let scrollTimeout;
|
||||||
|
let lastScrollTop = 0;
|
||||||
|
let lastScrollHeight = 0;
|
||||||
|
let lastClientHeight = 0;
|
||||||
|
|
||||||
targetElement.addEventListener("scroll", function() {
|
targetElement.addEventListener("scroll", function() {
|
||||||
let diff = targetElement.scrollHeight - targetElement.clientHeight;
|
let diff = targetElement.scrollHeight - targetElement.clientHeight;
|
||||||
let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff == 0;
|
let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0;
|
||||||
|
|
||||||
// Add scrolling class to disable hover effects
|
// Add scrolling class to disable hover effects
|
||||||
if (window.isScrolled || !isAtBottomNow) {
|
if (window.isScrolled || !isAtBottomNow) {
|
||||||
|
|
@ -157,9 +166,12 @@ targetElement.addEventListener("scroll", function() {
|
||||||
|
|
||||||
if(isAtBottomNow) {
|
if(isAtBottomNow) {
|
||||||
window.isScrolled = false;
|
window.isScrolled = false;
|
||||||
} else {
|
} else if (targetElement.scrollTop < lastScrollTop && targetElement.scrollHeight >= lastScrollHeight && targetElement.clientHeight <= lastClientHeight) {
|
||||||
window.isScrolled = true;
|
window.isScrolled = true;
|
||||||
}
|
}
|
||||||
|
lastScrollTop = targetElement.scrollTop;
|
||||||
|
lastScrollHeight = targetElement.scrollHeight;
|
||||||
|
lastClientHeight = targetElement.clientHeight;
|
||||||
|
|
||||||
// Clear previous timeout and set new one
|
// Clear previous timeout and set new one
|
||||||
clearTimeout(scrollTimeout);
|
clearTimeout(scrollTimeout);
|
||||||
|
|
@ -170,61 +182,28 @@ targetElement.addEventListener("scroll", function() {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create a MutationObserver instance
|
// Create a MutationObserver instance
|
||||||
const observer = new MutationObserver(function(mutations) {
|
const observer = new MutationObserver(function() {
|
||||||
// Check if this is just the scrolling class being toggled
|
|
||||||
const isScrollingClassOnly = mutations.every(mutation =>
|
|
||||||
mutation.type === "attributes" &&
|
|
||||||
mutation.attributeName === "class" &&
|
|
||||||
mutation.target === targetElement
|
|
||||||
);
|
|
||||||
|
|
||||||
if (targetElement.classList.contains("_generating")) {
|
if (targetElement.classList.contains("_generating")) {
|
||||||
typing.parentNode.classList.add("visible-dots");
|
typing.parentNode.classList.add("visible-dots");
|
||||||
document.getElementById("stop").style.display = "flex";
|
document.getElementById("stop").style.display = "flex";
|
||||||
document.getElementById("Generate").style.display = "none";
|
document.getElementById("Generate").style.display = "none";
|
||||||
|
// If the user is near the bottom, ensure auto-scroll is enabled
|
||||||
|
// for the new reply. This catches cases where isScrolled was
|
||||||
|
// incorrectly set to true by layout shifts during page load, etc.
|
||||||
|
const diff = targetElement.scrollHeight - targetElement.clientHeight;
|
||||||
|
if (Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0) {
|
||||||
|
window.isScrolled = false;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
typing.parentNode.classList.remove("visible-dots");
|
typing.parentNode.classList.remove("visible-dots");
|
||||||
document.getElementById("stop").style.display = "none";
|
document.getElementById("stop").style.display = "none";
|
||||||
document.getElementById("Generate").style.display = "flex";
|
document.getElementById("Generate").style.display = "flex";
|
||||||
}
|
}
|
||||||
|
|
||||||
doSyntaxHighlighting();
|
|
||||||
|
|
||||||
if (!window.isScrolled && !isScrollingClassOnly) {
|
|
||||||
const maxScroll = targetElement.scrollHeight - targetElement.clientHeight;
|
|
||||||
if (maxScroll > 0 && targetElement.scrollTop < maxScroll - 1) {
|
|
||||||
targetElement.scrollTop = maxScroll;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const chatElement = document.getElementById("chat");
|
|
||||||
if (chatElement && chatElement.getAttribute("data-mode") === "instruct") {
|
|
||||||
const messagesContainer = chatElement.querySelector(".messages");
|
|
||||||
const lastChild = messagesContainer?.lastElementChild;
|
|
||||||
const prevSibling = lastChild?.previousElementSibling;
|
|
||||||
if (lastChild && prevSibling) {
|
|
||||||
// Add padding to the messages container to create room for the last message.
|
|
||||||
// The purpose of this is to avoid constant scrolling during streaming in
|
|
||||||
// instruct mode.
|
|
||||||
let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight);
|
|
||||||
|
|
||||||
// Subtract header height when screen width is <= 924px
|
|
||||||
if (window.innerWidth <= 924) {
|
|
||||||
bufferHeight = Math.max(0, bufferHeight - 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
messagesContainer.style.paddingBottom = `${bufferHeight}px`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Configure the observer to watch for changes in the subtree and attributes
|
// Only watch for attribute changes on targetElement (e.g. _generating class)
|
||||||
const config = {
|
const config = {
|
||||||
childList: true,
|
attributes: true
|
||||||
subtree: true,
|
|
||||||
characterData: true,
|
|
||||||
attributeOldValue: true,
|
|
||||||
characterDataOldValue: true
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Start observing the target element
|
// Start observing the target element
|
||||||
|
|
@ -243,13 +222,10 @@ function isElementVisibleOnScreen(element) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
function doSyntaxHighlighting() {
|
window.doSyntaxHighlighting = function() {
|
||||||
const messageBodies = document.getElementById("chat").querySelectorAll(".message-body");
|
const messageBodies = document.getElementById("chat").querySelectorAll(".message-body");
|
||||||
|
|
||||||
if (messageBodies.length > 0) {
|
if (messageBodies.length > 0) {
|
||||||
observer.disconnect();
|
|
||||||
|
|
||||||
try {
|
|
||||||
let hasSeenVisible = false;
|
let hasSeenVisible = false;
|
||||||
|
|
||||||
// Go from last message to first
|
// Go from last message to first
|
||||||
|
|
@ -274,6 +250,7 @@ function doSyntaxHighlighting() {
|
||||||
renderMathInElement(container, {
|
renderMathInElement(container, {
|
||||||
delimiters: [
|
delimiters: [
|
||||||
{ left: "$$", right: "$$", display: true },
|
{ left: "$$", right: "$$", display: true },
|
||||||
|
{ left: "$", right: "$", display: false },
|
||||||
{ left: "\\(", right: "\\)", display: false },
|
{ left: "\\(", right: "\\)", display: false },
|
||||||
{ left: "\\[", right: "\\]", display: true },
|
{ left: "\\[", right: "\\]", display: true },
|
||||||
],
|
],
|
||||||
|
|
@ -286,20 +263,35 @@ function doSyntaxHighlighting() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} finally {
|
|
||||||
observer.observe(targetElement, config);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const doSyntaxHighlighting = window.doSyntaxHighlighting;
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
// Add some scrollbars
|
// Add some scrollbars
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
const textareaElements = document.querySelectorAll(".add_scrollbar textarea");
|
const scrollbarElements = document.querySelectorAll(".add_scrollbar textarea, .add_scrollbar .drag-drop-list");
|
||||||
for(i = 0; i < textareaElements.length; i++) {
|
for(i = 0; i < scrollbarElements.length; i++) {
|
||||||
textareaElements[i].classList.remove("scroll-hide");
|
scrollbarElements[i].classList.remove("scroll-hide");
|
||||||
textareaElements[i].classList.add("pretty_scrollbar");
|
scrollbarElements[i].classList.add("pretty_scrollbar");
|
||||||
textareaElements[i].style.resize = "none";
|
scrollbarElements[i].style.resize = "none";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//------------------------------------------------
|
||||||
|
// Tools: inject "Refresh list" link into the label
|
||||||
|
//------------------------------------------------
|
||||||
|
const toolsTitle = document.querySelector("#tools-group > [data-testid='block-info']");
|
||||||
|
const toolsInfo = toolsTitle ? toolsTitle.nextElementSibling : null;
|
||||||
|
if (toolsInfo) {
|
||||||
|
const refreshLink = document.createElement("span");
|
||||||
|
refreshLink.textContent = " [Refresh list]";
|
||||||
|
refreshLink.className = "tools-refresh-link";
|
||||||
|
refreshLink.addEventListener("click", function(e) {
|
||||||
|
e.preventDefault();
|
||||||
|
document.querySelector("#tools-refresh-btn").click();
|
||||||
|
});
|
||||||
|
toolsInfo.appendChild(refreshLink);
|
||||||
}
|
}
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
|
|
@ -560,6 +552,38 @@ document.querySelectorAll(".focus-on-chat-input").forEach(element => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
//------------------------------------------------
|
||||||
|
// "New chat" hover menu with incognito option
|
||||||
|
//------------------------------------------------
|
||||||
|
|
||||||
|
(function() {
|
||||||
|
const newChatBtn = document.getElementById("new-chat-btn");
|
||||||
|
|
||||||
|
const wrapper = document.createElement("div");
|
||||||
|
wrapper.id = "new-chat-wrapper";
|
||||||
|
newChatBtn.replaceWith(wrapper);
|
||||||
|
wrapper.appendChild(newChatBtn);
|
||||||
|
|
||||||
|
const arrow = document.createElement("span");
|
||||||
|
arrow.className = "new-chat-arrow";
|
||||||
|
arrow.textContent = "\u25BE";
|
||||||
|
|
||||||
|
const menu = document.createElement("div");
|
||||||
|
menu.className = "new-chat-menu";
|
||||||
|
const option = document.createElement("div");
|
||||||
|
option.className = "new-chat-menu-item";
|
||||||
|
option.textContent = "Incognito chat";
|
||||||
|
menu.appendChild(option);
|
||||||
|
|
||||||
|
arrow.appendChild(menu);
|
||||||
|
wrapper.appendChild(arrow);
|
||||||
|
|
||||||
|
option.addEventListener("click", function(e) {
|
||||||
|
e.stopPropagation();
|
||||||
|
document.querySelector("#incognito-chat-btn").click();
|
||||||
|
});
|
||||||
|
})();
|
||||||
|
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
// Fix a border around the "past chats" menu
|
// Fix a border around the "past chats" menu
|
||||||
//------------------------------------------------
|
//------------------------------------------------
|
||||||
|
|
@ -1089,15 +1113,13 @@ document.fonts.addEventListener("loadingdone", (event) => {
|
||||||
const currentHeight = chatInputRow.offsetHeight;
|
const currentHeight = chatInputRow.offsetHeight;
|
||||||
const heightDifference = currentHeight - originalHeight;
|
const heightDifference = currentHeight - originalHeight;
|
||||||
chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`;
|
chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`;
|
||||||
|
if (!window.isScrolled) {
|
||||||
|
chatParent.scrollTop = chatParent.scrollHeight - chatParent.clientHeight;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Watch for changes that might affect height
|
// Watch for size changes that affect height
|
||||||
const observer = new MutationObserver(updateMargin);
|
new ResizeObserver(updateMargin).observe(chatInputRow);
|
||||||
observer.observe(chatInputRow, {
|
|
||||||
childList: true,
|
|
||||||
subtree: true,
|
|
||||||
attributes: true
|
|
||||||
});
|
|
||||||
|
|
||||||
// Also listen for window resize
|
// Also listen for window resize
|
||||||
window.addEventListener("resize", updateMargin);
|
window.addEventListener("resize", updateMargin);
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,6 @@ from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
def add_lora_to_model(lora_names):
|
def add_lora_to_model(lora_names):
|
||||||
if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav2HF'] or shared.args.loader in ['ExLlamav2', 'ExLlamav2_HF']:
|
|
||||||
add_lora_exllamav2(lora_names)
|
|
||||||
else:
|
|
||||||
add_lora_transformers(lora_names)
|
add_lora_transformers(lora_names)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -19,32 +16,6 @@ def get_lora_path(lora_name):
|
||||||
return Path(f"{shared.args.lora_dir}/{lora_name}")
|
return Path(f"{shared.args.lora_dir}/{lora_name}")
|
||||||
|
|
||||||
|
|
||||||
def add_lora_exllamav2(lora_names):
|
|
||||||
|
|
||||||
from exllamav2 import ExLlamaV2Lora
|
|
||||||
|
|
||||||
if isinstance(shared.model.loras, list):
|
|
||||||
for lora in shared.model.loras:
|
|
||||||
lora.unload()
|
|
||||||
|
|
||||||
if len(lora_names) > 0:
|
|
||||||
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
|
||||||
shared.model.loras = []
|
|
||||||
for lora_name in lora_names:
|
|
||||||
lora_path = get_lora_path(lora_name)
|
|
||||||
if shared.model.__class__.__name__ == 'Exllamav2Model':
|
|
||||||
lora = ExLlamaV2Lora.from_directory(shared.model.model, str(lora_path))
|
|
||||||
else:
|
|
||||||
lora = ExLlamaV2Lora.from_directory(shared.model.ex_model, str(lora_path))
|
|
||||||
|
|
||||||
shared.model.loras.append(lora)
|
|
||||||
|
|
||||||
shared.lora_names = lora_names
|
|
||||||
else:
|
|
||||||
shared.lora_names = []
|
|
||||||
shared.model.loras = None
|
|
||||||
|
|
||||||
|
|
||||||
def add_lora_transformers(lora_names):
|
def add_lora_transformers(lora_names):
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
|
|
||||||
|
|
@ -77,9 +48,7 @@ def add_lora_transformers(lora_names):
|
||||||
if len(lora_names) > 0:
|
if len(lora_names) > 0:
|
||||||
params = {}
|
params = {}
|
||||||
if not shared.args.cpu:
|
if not shared.args.cpu:
|
||||||
if shared.args.load_in_4bit or shared.args.load_in_8bit:
|
if not shared.args.load_in_4bit and not shared.args.load_in_8bit:
|
||||||
params['peft_type'] = shared.model.dtype
|
|
||||||
else:
|
|
||||||
params['dtype'] = shared.model.dtype
|
params['dtype'] = shared.model.dtype
|
||||||
if hasattr(shared.model, "hf_device_map"):
|
if hasattr(shared.model, "hf_device_map"):
|
||||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||||
|
|
|
||||||
|
|
@ -1,96 +0,0 @@
|
||||||
import builtins
|
|
||||||
import io
|
|
||||||
import re
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from modules import shared, ui
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
|
|
||||||
original_open = open
|
|
||||||
original_get = requests.get
|
|
||||||
original_print = print
|
|
||||||
|
|
||||||
|
|
||||||
class RequestBlocker:
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
requests.get = my_get
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
requests.get = original_get
|
|
||||||
|
|
||||||
|
|
||||||
class OpenMonkeyPatch:
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
builtins.open = my_open
|
|
||||||
builtins.print = my_print
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
builtins.open = original_open
|
|
||||||
builtins.print = original_print
|
|
||||||
|
|
||||||
|
|
||||||
def my_get(url, **kwargs):
|
|
||||||
logger.info('Unwanted HTTP request redirected to localhost :)')
|
|
||||||
kwargs.setdefault('allow_redirects', True)
|
|
||||||
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def my_open(*args, **kwargs):
|
|
||||||
filename = str(args[0])
|
|
||||||
if filename.endswith(('index.html', 'share.html')):
|
|
||||||
with original_open(*args, **kwargs) as f:
|
|
||||||
file_contents = f.read()
|
|
||||||
|
|
||||||
if len(args) > 1 and args[1] == 'rb':
|
|
||||||
file_contents = file_contents.decode('utf-8')
|
|
||||||
|
|
||||||
file_contents = file_contents.replace('\t\t<script\n\t\t\tsrc="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"\n\t\t\tasync\n\t\t></script>', '')
|
|
||||||
file_contents = file_contents.replace('cdnjs.cloudflare.com', '127.0.0.1')
|
|
||||||
file_contents = file_contents.replace(
|
|
||||||
'</head>',
|
|
||||||
'\n <link rel="preload" href="file/css/Inter/Inter-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>'
|
|
||||||
'\n <link rel="preload" href="file/css/Inter/Inter-Italic-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>'
|
|
||||||
'\n <link rel="preload" href="file/css/NotoSans/NotoSans-Medium.woff2" as="font" type="font/woff2" crossorigin>'
|
|
||||||
'\n <link rel="preload" href="file/css/NotoSans/NotoSans-MediumItalic.woff2" as="font" type="font/woff2" crossorigin>'
|
|
||||||
'\n <link rel="preload" href="file/css/NotoSans/NotoSans-Bold.woff2" as="font" type="font/woff2" crossorigin>'
|
|
||||||
'\n <script src="file/js/katex/katex.min.js"></script>'
|
|
||||||
'\n <script src="file/js/katex/auto-render.min.js"></script>'
|
|
||||||
'\n <script src="file/js/highlightjs/highlight.min.js"></script>'
|
|
||||||
'\n <script src="file/js/highlightjs/highlightjs-copy.min.js"></script>'
|
|
||||||
'\n <script src="file/js/morphdom/morphdom-umd.min.js"></script>'
|
|
||||||
f'\n <link id="highlight-css" rel="stylesheet" href="file/css/highlightjs/{"github-dark" if shared.settings["dark_theme"] else "github"}.min.css">'
|
|
||||||
'\n <script>hljs.addPlugin(new CopyButtonPlugin());</script>'
|
|
||||||
f'\n <script>{ui.global_scope_js}</script>'
|
|
||||||
'\n </head>'
|
|
||||||
)
|
|
||||||
|
|
||||||
file_contents = re.sub(
|
|
||||||
r'@media \(prefers-color-scheme: dark\) \{\s*body \{([^}]*)\}\s*\}',
|
|
||||||
r'body.dark {\1}',
|
|
||||||
file_contents,
|
|
||||||
flags=re.DOTALL
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(args) > 1 and args[1] == 'rb':
|
|
||||||
file_contents = file_contents.encode('utf-8')
|
|
||||||
return io.BytesIO(file_contents)
|
|
||||||
else:
|
|
||||||
return io.StringIO(file_contents)
|
|
||||||
|
|
||||||
else:
|
|
||||||
return original_open(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def my_print(*args, **kwargs):
|
|
||||||
if len(args) > 0 and 'To create a public link, set `share=True`' in args[0]:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
if len(args) > 0 and 'Running on local URL' in args[0]:
|
|
||||||
args = list(args)
|
|
||||||
args[0] = f"\n{args[0].strip()}\n"
|
|
||||||
args = tuple(args)
|
|
||||||
|
|
||||||
original_print(*args, **kwargs)
|
|
||||||
|
|
@ -37,7 +37,7 @@ class Iteratorize:
|
||||||
ret = self.mfunc(callback=_callback, *args, **self.kwargs)
|
ret = self.mfunc(callback=_callback, *args, **self.kwargs)
|
||||||
except StopNowException:
|
except StopNowException:
|
||||||
pass
|
pass
|
||||||
except:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
850
modules/chat.py
850
modules/chat.py
File diff suppressed because it is too large
Load diff
|
|
@ -1,74 +0,0 @@
|
||||||
def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
|
|
||||||
'''
|
|
||||||
DeepSpeed configuration
|
|
||||||
https://huggingface.co/docs/transformers/main_classes/deepspeed
|
|
||||||
'''
|
|
||||||
|
|
||||||
if nvme_offload_dir:
|
|
||||||
ds_config = {
|
|
||||||
"fp16": {
|
|
||||||
"enabled": not ds_bf16,
|
|
||||||
},
|
|
||||||
"bf16": {
|
|
||||||
"enabled": ds_bf16,
|
|
||||||
},
|
|
||||||
"zero_optimization": {
|
|
||||||
"stage": 3,
|
|
||||||
"offload_param": {
|
|
||||||
"device": "nvme",
|
|
||||||
"nvme_path": nvme_offload_dir,
|
|
||||||
"pin_memory": True,
|
|
||||||
"buffer_count": 5,
|
|
||||||
"buffer_size": 1e9,
|
|
||||||
"max_in_cpu": 1e9
|
|
||||||
},
|
|
||||||
"overlap_comm": True,
|
|
||||||
"reduce_bucket_size": "auto",
|
|
||||||
"contiguous_gradients": True,
|
|
||||||
"sub_group_size": 1e8,
|
|
||||||
"stage3_prefetch_bucket_size": "auto",
|
|
||||||
"stage3_param_persistence_threshold": "auto",
|
|
||||||
"stage3_max_live_parameters": "auto",
|
|
||||||
"stage3_max_reuse_distance": "auto",
|
|
||||||
},
|
|
||||||
"aio": {
|
|
||||||
"block_size": 262144,
|
|
||||||
"queue_depth": 32,
|
|
||||||
"thread_count": 1,
|
|
||||||
"single_submit": False,
|
|
||||||
"overlap_events": True
|
|
||||||
},
|
|
||||||
"steps_per_print": 2000,
|
|
||||||
"train_batch_size": train_batch_size,
|
|
||||||
"train_micro_batch_size_per_gpu": 1,
|
|
||||||
"wall_clock_breakdown": False
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
ds_config = {
|
|
||||||
"fp16": {
|
|
||||||
"enabled": not ds_bf16,
|
|
||||||
},
|
|
||||||
"bf16": {
|
|
||||||
"enabled": ds_bf16,
|
|
||||||
},
|
|
||||||
"zero_optimization": {
|
|
||||||
"stage": 3,
|
|
||||||
"offload_param": {
|
|
||||||
"device": "cpu",
|
|
||||||
"pin_memory": True
|
|
||||||
},
|
|
||||||
"overlap_comm": True,
|
|
||||||
"contiguous_gradients": True,
|
|
||||||
"reduce_bucket_size": "auto",
|
|
||||||
"stage3_prefetch_bucket_size": "auto",
|
|
||||||
"stage3_param_persistence_threshold": "auto",
|
|
||||||
"stage3_max_live_parameters": "auto",
|
|
||||||
"stage3_max_reuse_distance": "auto",
|
|
||||||
},
|
|
||||||
"steps_per_print": 2000,
|
|
||||||
"train_batch_size": train_batch_size,
|
|
||||||
"train_micro_batch_size_per_gpu": 1,
|
|
||||||
"wall_clock_breakdown": False
|
|
||||||
}
|
|
||||||
|
|
||||||
return ds_config
|
|
||||||
|
|
@ -12,8 +12,8 @@ from modules.text_generation import encode
|
||||||
|
|
||||||
|
|
||||||
def load_past_evaluations():
|
def load_past_evaluations():
|
||||||
if Path('user_data/logs/evaluations.csv').exists():
|
if (shared.user_data_dir / 'logs' / 'evaluations.csv').exists():
|
||||||
df = pd.read_csv(Path('user_data/logs/evaluations.csv'), dtype=str)
|
df = pd.read_csv(shared.user_data_dir / 'logs' / 'evaluations.csv', dtype=str)
|
||||||
df['Perplexity'] = pd.to_numeric(df['Perplexity'])
|
df['Perplexity'] = pd.to_numeric(df['Perplexity'])
|
||||||
return df
|
return df
|
||||||
else:
|
else:
|
||||||
|
|
@ -26,7 +26,7 @@ past_evaluations = load_past_evaluations()
|
||||||
def save_past_evaluations(df):
|
def save_past_evaluations(df):
|
||||||
global past_evaluations
|
global past_evaluations
|
||||||
past_evaluations = df
|
past_evaluations = df
|
||||||
filepath = Path('user_data/logs/evaluations.csv')
|
filepath = shared.user_data_dir / 'logs' / 'evaluations.csv'
|
||||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
df.to_csv(filepath, index=False)
|
df.to_csv(filepath, index=False)
|
||||||
|
|
||||||
|
|
@ -46,10 +46,6 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
||||||
logger.error("Perplexity evaluation is not implemented for the llama.cpp loader.")
|
logger.error("Perplexity evaluation is not implemented for the llama.cpp loader.")
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
if shared.args.loader == "ExLlamav2":
|
|
||||||
logger.error("ExLlamav2_HF is required for perplexity evaluation with EXL2 models. Please reload the model with ExLlamav2_HF instead of ExLlamav2.")
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
if not shared.args.no_use_fast:
|
if not shared.args.no_use_fast:
|
||||||
logger.warning("--no_use_fast is not set. If tokenizing the input dataset takes a long time, try reloading the model with that option set/checked.")
|
logger.warning("--no_use_fast is not set. If tokenizing the input dataset takes a long time, try reloading the model with that option set/checked.")
|
||||||
|
|
||||||
|
|
@ -69,7 +65,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
||||||
data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
|
data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
|
||||||
text = " ".join(data['sentence'])
|
text = " ".join(data['sentence'])
|
||||||
else:
|
else:
|
||||||
with open(Path(f'user_data/training/datasets/{input_dataset}.txt'), 'r', encoding='utf-8') as f:
|
with open(shared.user_data_dir / 'training' / 'datasets' / f'{input_dataset}.txt', 'r', encoding='utf-8') as f:
|
||||||
text = f.read()
|
text = f.read()
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
|
|
@ -86,7 +82,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
||||||
update_model_parameters(model_settings) # hijacking the command-line arguments
|
update_model_parameters(model_settings) # hijacking the command-line arguments
|
||||||
unload_model()
|
unload_model()
|
||||||
shared.model, shared.tokenizer = load_model(model)
|
shared.model, shared.tokenizer = load_model(model)
|
||||||
except:
|
except Exception:
|
||||||
cumulative_log += f"Failed to load `{model}`. Moving on.\n\n"
|
cumulative_log += f"Failed to load `{model}`. Moving on.\n\n"
|
||||||
yield cumulative_log
|
yield cumulative_log
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -1,247 +0,0 @@
|
||||||
import json
|
|
||||||
import traceback
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from exllamav2 import (
|
|
||||||
ExLlamaV2,
|
|
||||||
ExLlamaV2Cache,
|
|
||||||
ExLlamaV2Cache_8bit,
|
|
||||||
ExLlamaV2Cache_Q4,
|
|
||||||
ExLlamaV2Cache_Q6,
|
|
||||||
ExLlamaV2Cache_Q8,
|
|
||||||
ExLlamaV2Cache_TP,
|
|
||||||
ExLlamaV2Config,
|
|
||||||
ExLlamaV2Tokenizer
|
|
||||||
)
|
|
||||||
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
|
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
from modules.text_generation import get_max_prompt_length
|
|
||||||
|
|
||||||
try:
|
|
||||||
import flash_attn
|
|
||||||
except Exception:
|
|
||||||
logger.warning('Failed to load flash-attention due to the following error:\n')
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
class Exllamav2Model:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(self, path_to_model):
|
|
||||||
|
|
||||||
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
|
||||||
|
|
||||||
config = ExLlamaV2Config()
|
|
||||||
config.model_dir = str(path_to_model)
|
|
||||||
config.prepare()
|
|
||||||
|
|
||||||
config.max_seq_len = shared.args.ctx_size
|
|
||||||
config.scale_pos_emb = shared.args.compress_pos_emb
|
|
||||||
config.scale_alpha_value = shared.args.alpha_value
|
|
||||||
config.no_flash_attn = shared.args.no_flash_attn
|
|
||||||
config.no_xformers = shared.args.no_xformers
|
|
||||||
config.no_sdpa = shared.args.no_sdpa
|
|
||||||
config.num_experts_per_token = int(shared.args.num_experts_per_token)
|
|
||||||
|
|
||||||
model = ExLlamaV2(config)
|
|
||||||
|
|
||||||
split = None
|
|
||||||
if shared.args.gpu_split:
|
|
||||||
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
|
||||||
|
|
||||||
if shared.args.enable_tp:
|
|
||||||
model.load_tp(split)
|
|
||||||
elif not shared.args.autosplit:
|
|
||||||
model.load(split)
|
|
||||||
|
|
||||||
# Determine the correct cache type
|
|
||||||
kv_cache_type = shared.args.cache_type.lower()
|
|
||||||
|
|
||||||
if kv_cache_type == 'fp16':
|
|
||||||
cache_type = ExLlamaV2Cache
|
|
||||||
elif kv_cache_type == 'fp8':
|
|
||||||
cache_type = ExLlamaV2Cache_8bit
|
|
||||||
elif kv_cache_type == 'q8':
|
|
||||||
cache_type = ExLlamaV2Cache_Q8
|
|
||||||
elif kv_cache_type == 'q6':
|
|
||||||
cache_type = ExLlamaV2Cache_Q6
|
|
||||||
elif kv_cache_type == 'q4':
|
|
||||||
cache_type = ExLlamaV2Cache_Q4
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
|
|
||||||
|
|
||||||
# Use TP if specified
|
|
||||||
if shared.args.enable_tp:
|
|
||||||
cache = ExLlamaV2Cache_TP(model, base=cache_type)
|
|
||||||
else:
|
|
||||||
cache = cache_type(model, lazy=shared.args.autosplit)
|
|
||||||
|
|
||||||
if shared.args.autosplit and not shared.args.enable_tp:
|
|
||||||
model.load_autosplit(cache)
|
|
||||||
|
|
||||||
tokenizer = ExLlamaV2Tokenizer(config)
|
|
||||||
|
|
||||||
# Initialize draft model for speculative decoding
|
|
||||||
draft_model = None
|
|
||||||
draft_cache = None
|
|
||||||
|
|
||||||
if shared.args.model_draft and shared.args.model_draft.lower() not in ["none", ""]:
|
|
||||||
logger.info(f"Loading draft model for speculative decoding: {shared.args.model_draft}")
|
|
||||||
|
|
||||||
# Find the draft model path
|
|
||||||
draft_path = Path(shared.args.model_draft)
|
|
||||||
if not draft_path.exists():
|
|
||||||
draft_path = Path(f'{shared.args.model_dir}') / Path(shared.args.model_draft)
|
|
||||||
|
|
||||||
draft_config = ExLlamaV2Config()
|
|
||||||
draft_config.model_dir = str(draft_path)
|
|
||||||
draft_config.prepare()
|
|
||||||
draft_config.arch_compat_overrides()
|
|
||||||
|
|
||||||
# Set context size for draft model
|
|
||||||
if shared.args.ctx_size_draft > 0:
|
|
||||||
draft_config.max_seq_len = shared.args.ctx_size_draft
|
|
||||||
else:
|
|
||||||
draft_config.max_seq_len = config.max_seq_len
|
|
||||||
|
|
||||||
draft_model = ExLlamaV2(draft_config)
|
|
||||||
draft_cache = cache_type(draft_model, lazy=True)
|
|
||||||
draft_model.load_autosplit(draft_cache)
|
|
||||||
|
|
||||||
logger.info(f"Draft model loaded successfully with max_draft={shared.args.draft_max}")
|
|
||||||
|
|
||||||
generator = ExLlamaV2StreamingGenerator(
|
|
||||||
model,
|
|
||||||
cache,
|
|
||||||
tokenizer,
|
|
||||||
draft_model=draft_model,
|
|
||||||
draft_cache=draft_cache,
|
|
||||||
num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
result = self()
|
|
||||||
result.model = model
|
|
||||||
result.cache = cache
|
|
||||||
result.tokenizer = tokenizer
|
|
||||||
result.generator = generator
|
|
||||||
result.loras = None
|
|
||||||
result.draft_model = draft_model
|
|
||||||
result.draft_cache = draft_cache
|
|
||||||
return result, result
|
|
||||||
|
|
||||||
def encode(self, string, **kwargs):
|
|
||||||
add_bos = kwargs.pop('add_bos', True)
|
|
||||||
return self.tokenizer.encode(string, add_bos=add_bos, encode_special_tokens=True, **kwargs)
|
|
||||||
|
|
||||||
def decode(self, ids, **kwargs):
|
|
||||||
if isinstance(ids, list):
|
|
||||||
ids = torch.tensor([ids])
|
|
||||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
|
||||||
ids = ids.view(1, -1)
|
|
||||||
|
|
||||||
return self.tokenizer.decode(ids, decode_special_tokens=True)[0]
|
|
||||||
|
|
||||||
def get_logits(self, token_ids, **kwargs):
|
|
||||||
self.cache.current_seq_len = 0
|
|
||||||
if token_ids.shape[-1] > 1:
|
|
||||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
|
|
||||||
|
|
||||||
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()
|
|
||||||
|
|
||||||
def generate_with_streaming(self, prompt, state):
|
|
||||||
settings = ExLlamaV2Sampler.Settings()
|
|
||||||
|
|
||||||
settings.token_repetition_penalty = state['repetition_penalty']
|
|
||||||
settings.token_repetition_range = -1 if state['repetition_penalty_range'] <= 0 else state['repetition_penalty_range']
|
|
||||||
|
|
||||||
settings.token_frequency_penalty = state['frequency_penalty']
|
|
||||||
settings.token_presence_penalty = state['presence_penalty']
|
|
||||||
|
|
||||||
settings.temperature = state['temperature']
|
|
||||||
settings.smoothing_factor = state['smoothing_factor']
|
|
||||||
settings.min_temp = state['dynatemp_low'] if state['dynamic_temperature'] else 0
|
|
||||||
settings.max_temp = state['dynatemp_high'] if state['dynamic_temperature'] else 0
|
|
||||||
settings.temp_exponent = state['dynatemp_exponent']
|
|
||||||
settings.top_k = state['top_k']
|
|
||||||
settings.top_p = state['top_p']
|
|
||||||
settings.top_a = state['top_a']
|
|
||||||
settings.min_p = state['min_p']
|
|
||||||
settings.tfs = state['tfs']
|
|
||||||
settings.typical = state['typical_p']
|
|
||||||
|
|
||||||
settings.temperature_last = state['temperature_last']
|
|
||||||
|
|
||||||
settings.mirostat = state['mirostat_mode'] == 2
|
|
||||||
settings.mirostat_tau = state['mirostat_tau']
|
|
||||||
settings.mirostat_eta = state['mirostat_eta']
|
|
||||||
|
|
||||||
if state['ban_eos_token']:
|
|
||||||
settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
|
|
||||||
|
|
||||||
if state['custom_token_bans']:
|
|
||||||
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
|
|
||||||
if len(to_ban) > 0:
|
|
||||||
settings.disallow_tokens(self.tokenizer, to_ban)
|
|
||||||
|
|
||||||
settings.dry_allowed_length = state['dry_allowed_length']
|
|
||||||
settings.dry_base = state['dry_base']
|
|
||||||
settings.dry_multiplier = state['dry_multiplier']
|
|
||||||
|
|
||||||
# Dry sequence breakers processing
|
|
||||||
if state['dry_multiplier'] > 0 and state['dry_sequence_breakers']:
|
|
||||||
dry_sequence_breakers = state['dry_sequence_breakers']
|
|
||||||
|
|
||||||
# Support both JSON array notation and comma-separated strings.
|
|
||||||
if not dry_sequence_breakers.startswith("["):
|
|
||||||
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
|
|
||||||
|
|
||||||
sequence_breaker_strings = json.loads(dry_sequence_breakers)
|
|
||||||
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
|
|
||||||
sequence_breakers = {
|
|
||||||
self.encode(f"a{s}")[0, -1].item() for s in sequence_breaker_strings
|
|
||||||
}
|
|
||||||
|
|
||||||
settings.dry_sequence_breakers = sequence_breakers
|
|
||||||
|
|
||||||
settings.xtc_probability = state['xtc_probability']
|
|
||||||
settings.xtc_threshold = state['xtc_threshold']
|
|
||||||
|
|
||||||
ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
|
|
||||||
ids = ids[:, -get_max_prompt_length(state):]
|
|
||||||
|
|
||||||
if state['auto_max_new_tokens']:
|
|
||||||
max_new_tokens = state['truncation_length'] - ids.shape[-1]
|
|
||||||
else:
|
|
||||||
max_new_tokens = state['max_new_tokens']
|
|
||||||
|
|
||||||
# Reset speculative decoding stats if using a draft model
|
|
||||||
if hasattr(self, 'draft_model') and self.draft_model is not None:
|
|
||||||
self.generator.reset_sd_stats()
|
|
||||||
|
|
||||||
self.generator.begin_stream(ids, settings, loras=self.loras)
|
|
||||||
|
|
||||||
decoded_text = ''
|
|
||||||
for i in range(max_new_tokens):
|
|
||||||
chunk, eos, _ = self.generator.stream()
|
|
||||||
if eos or shared.stop_everything:
|
|
||||||
break
|
|
||||||
|
|
||||||
decoded_text += chunk
|
|
||||||
yield decoded_text
|
|
||||||
|
|
||||||
# Log speculative decoding stats if using draft model
|
|
||||||
if hasattr(self, 'draft_model') and self.draft_model is not None:
|
|
||||||
efficiency, accuracy, total_tokens, total_draft_tokens, accepted_draft_tokens = self.generator.get_sd_stats()
|
|
||||||
logger.info(f"Speculative decoding: accepted={accepted_draft_tokens}/{total_draft_tokens} tokens")
|
|
||||||
|
|
||||||
def generate(self, prompt, state):
|
|
||||||
output = ''
|
|
||||||
for output in self.generate_with_streaming(prompt, state):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
@ -1,203 +0,0 @@
|
||||||
import os
|
|
||||||
import traceback
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from exllamav2 import (
|
|
||||||
ExLlamaV2,
|
|
||||||
ExLlamaV2Cache,
|
|
||||||
ExLlamaV2Cache_8bit,
|
|
||||||
ExLlamaV2Cache_Q4,
|
|
||||||
ExLlamaV2Cache_Q6,
|
|
||||||
ExLlamaV2Cache_Q8,
|
|
||||||
ExLlamaV2Cache_TP,
|
|
||||||
ExLlamaV2Config
|
|
||||||
)
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
from transformers import (
|
|
||||||
GenerationConfig,
|
|
||||||
GenerationMixin,
|
|
||||||
PretrainedConfig,
|
|
||||||
PreTrainedModel
|
|
||||||
)
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
from modules.logging_colors import logger
|
|
||||||
|
|
||||||
try:
|
|
||||||
import flash_attn
|
|
||||||
except Exception:
|
|
||||||
logger.warning('Failed to load flash-attention due to the following error:\n')
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
class Exllamav2HF(PreTrainedModel, GenerationMixin):
|
|
||||||
def __init__(self, config: ExLlamaV2Config):
|
|
||||||
hf_config = PretrainedConfig.from_pretrained(config.model_dir)
|
|
||||||
super().__init__(hf_config)
|
|
||||||
|
|
||||||
self.ex_config = config
|
|
||||||
self.loras = None
|
|
||||||
self.generation_config = GenerationConfig()
|
|
||||||
|
|
||||||
self.ex_model = ExLlamaV2(config)
|
|
||||||
|
|
||||||
split = None
|
|
||||||
if shared.args.gpu_split:
|
|
||||||
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
|
||||||
|
|
||||||
if shared.args.enable_tp:
|
|
||||||
self.ex_model.load_tp(split)
|
|
||||||
elif not shared.args.autosplit:
|
|
||||||
self.ex_model.load(split)
|
|
||||||
|
|
||||||
# Determine the correct cache type
|
|
||||||
kv_cache_type = shared.args.cache_type.lower()
|
|
||||||
|
|
||||||
if kv_cache_type == 'fp16':
|
|
||||||
cache_type = ExLlamaV2Cache
|
|
||||||
elif kv_cache_type == 'fp8':
|
|
||||||
cache_type = ExLlamaV2Cache_8bit
|
|
||||||
elif kv_cache_type == 'q8':
|
|
||||||
cache_type = ExLlamaV2Cache_Q8
|
|
||||||
elif kv_cache_type == 'q6':
|
|
||||||
cache_type = ExLlamaV2Cache_Q6
|
|
||||||
elif kv_cache_type == 'q4':
|
|
||||||
cache_type = ExLlamaV2Cache_Q4
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid cache type for ExLlamaV2: {kv_cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
|
|
||||||
|
|
||||||
# Use TP if specified
|
|
||||||
if shared.args.enable_tp:
|
|
||||||
self.ex_cache = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
|
|
||||||
else:
|
|
||||||
self.ex_cache = cache_type(self.ex_model, lazy=shared.args.autosplit)
|
|
||||||
|
|
||||||
if shared.args.autosplit and not shared.args.enable_tp:
|
|
||||||
self.ex_model.load_autosplit(self.ex_cache)
|
|
||||||
|
|
||||||
self.past_seq = None
|
|
||||||
if shared.args.cfg_cache:
|
|
||||||
if shared.args.enable_tp:
|
|
||||||
self.ex_cache_negative = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
|
|
||||||
else:
|
|
||||||
self.ex_cache_negative = cache_type(self.ex_model, lazy=shared.args.autosplit)
|
|
||||||
|
|
||||||
self.past_seq_negative = None
|
|
||||||
|
|
||||||
def _validate_model_class(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
|
||||||
return {'input_ids': input_ids, **kwargs}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
|
||||||
return torch.device(0)
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
use_cache = kwargs.get('use_cache', True)
|
|
||||||
labels = kwargs.get('labels', None)
|
|
||||||
past_key_values = kwargs.get('past_key_values', None)
|
|
||||||
|
|
||||||
if len(args) > 0:
|
|
||||||
if not shared.args.cfg_cache:
|
|
||||||
logger.error("Please enable the cfg-cache option to use CFG with ExLlamav2_HF.")
|
|
||||||
return
|
|
||||||
|
|
||||||
input_ids = args[0]
|
|
||||||
is_negative = True
|
|
||||||
past_seq = self.past_seq_negative
|
|
||||||
ex_cache = self.ex_cache_negative
|
|
||||||
else:
|
|
||||||
input_ids = kwargs['input_ids']
|
|
||||||
is_negative = False
|
|
||||||
past_seq = self.past_seq
|
|
||||||
ex_cache = self.ex_cache
|
|
||||||
|
|
||||||
seq = input_ids[0].tolist()
|
|
||||||
if is_negative and past_key_values is not None:
|
|
||||||
seq = past_key_values + seq
|
|
||||||
|
|
||||||
seq_tensor = torch.tensor(seq)
|
|
||||||
reset = True
|
|
||||||
|
|
||||||
# Make the forward call
|
|
||||||
if labels is None:
|
|
||||||
if past_seq is not None:
|
|
||||||
min_length = min(past_seq.shape[0], seq_tensor.shape[0])
|
|
||||||
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
|
|
||||||
if len(indices) > 0:
|
|
||||||
longest_prefix = indices[0].item()
|
|
||||||
else:
|
|
||||||
longest_prefix = min_length
|
|
||||||
|
|
||||||
if longest_prefix > 0:
|
|
||||||
reset = False
|
|
||||||
ex_cache.current_seq_len = longest_prefix
|
|
||||||
if len(seq_tensor) - longest_prefix > 1:
|
|
||||||
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
|
|
||||||
elif len(seq_tensor) == longest_prefix:
|
|
||||||
# Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
|
|
||||||
# because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
|
|
||||||
ex_cache.current_seq_len -= 1
|
|
||||||
|
|
||||||
if reset:
|
|
||||||
ex_cache.current_seq_len = 0
|
|
||||||
if len(seq_tensor) > 1:
|
|
||||||
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
|
|
||||||
|
|
||||||
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device).float()
|
|
||||||
else:
|
|
||||||
ex_cache.current_seq_len = 0
|
|
||||||
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras).float()
|
|
||||||
|
|
||||||
if is_negative:
|
|
||||||
self.past_seq_negative = seq_tensor
|
|
||||||
else:
|
|
||||||
self.past_seq = seq_tensor
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
shift_logits = shift_logits.view(-1, logits.shape[-1])
|
|
||||||
shift_labels = shift_labels.view(-1)
|
|
||||||
# Enable model parallelism
|
|
||||||
shift_labels = shift_labels.to(shift_logits.device)
|
|
||||||
loss = loss_fct(shift_logits, shift_labels)
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
|
||||||
assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
|
|
||||||
if isinstance(pretrained_model_name_or_path, str):
|
|
||||||
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
|
|
||||||
|
|
||||||
pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
|
|
||||||
|
|
||||||
config = ExLlamaV2Config()
|
|
||||||
config.model_dir = str(pretrained_model_name_or_path)
|
|
||||||
config.prepare()
|
|
||||||
|
|
||||||
config.max_seq_len = shared.args.ctx_size
|
|
||||||
config.scale_pos_emb = shared.args.compress_pos_emb
|
|
||||||
config.scale_alpha_value = shared.args.alpha_value
|
|
||||||
config.no_flash_attn = shared.args.no_flash_attn
|
|
||||||
config.no_xformers = shared.args.no_xformers
|
|
||||||
config.no_sdpa = shared.args.no_sdpa
|
|
||||||
config.num_experts_per_token = int(shared.args.num_experts_per_token)
|
|
||||||
|
|
||||||
return Exllamav2HF(config)
|
|
||||||
|
|
@ -1,3 +1,6 @@
|
||||||
|
import math
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, Tuple
|
from typing import Any, List, Tuple
|
||||||
|
|
@ -7,8 +10,10 @@ import torch
|
||||||
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
|
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
|
||||||
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
|
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
|
||||||
from exllamav3.generator import Job
|
from exllamav3.generator import Job
|
||||||
|
from exllamav3.generator.filter import Filter
|
||||||
from exllamav3.generator.sampler import (
|
from exllamav3.generator.sampler import (
|
||||||
CustomSampler,
|
CustomSampler,
|
||||||
|
SS_AdaptiveP,
|
||||||
SS_Argmax,
|
SS_Argmax,
|
||||||
SS_MinP,
|
SS_MinP,
|
||||||
SS_PresFreqP,
|
SS_PresFreqP,
|
||||||
|
|
@ -33,10 +38,95 @@ except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
class LogitBiasFilter(Filter):
|
||||||
|
"""Filter subclass that applies a static additive logit bias mask."""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer, logit_bias_dict):
|
||||||
|
super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False)
|
||||||
|
self.logit_bias_dict = logit_bias_dict
|
||||||
|
self._mask = None
|
||||||
|
|
||||||
|
def reset(self): pass
|
||||||
|
def accept_token(self, token): pass
|
||||||
|
def is_completed(self): return False
|
||||||
|
def use_background_worker(self): return False
|
||||||
|
|
||||||
|
def get_next_logit_mask(self):
|
||||||
|
if self._mask is None:
|
||||||
|
self._mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
|
||||||
|
for token_id_str, bias in self.logit_bias_dict.items():
|
||||||
|
token_id = int(token_id_str)
|
||||||
|
if 0 <= token_id < self.vocab_size:
|
||||||
|
self._mask[0, token_id] = bias
|
||||||
|
return self._mask
|
||||||
|
|
||||||
|
|
||||||
|
class ConcurrentGenerator:
|
||||||
|
def __init__(self, generator):
|
||||||
|
self.generator = generator
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self.job_queues = {}
|
||||||
|
self.active = True
|
||||||
|
self.has_jobs = threading.Event()
|
||||||
|
self.thread = threading.Thread(target=self._iterate_loop, daemon=True)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
def _iterate_loop(self):
|
||||||
|
while self.active:
|
||||||
|
self.has_jobs.wait(timeout=0.5)
|
||||||
|
with self.lock:
|
||||||
|
if not self.job_queues:
|
||||||
|
self.has_jobs.clear()
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
results = self.generator.iterate()
|
||||||
|
except Exception:
|
||||||
|
logger.error("Exception in ConcurrentGenerator iterate loop:\n" + traceback.format_exc())
|
||||||
|
for q in self.job_queues.values():
|
||||||
|
q.put(None)
|
||||||
|
self.job_queues.clear()
|
||||||
|
self.generator.clear_queue()
|
||||||
|
self.has_jobs.clear()
|
||||||
|
continue
|
||||||
|
for result in results:
|
||||||
|
job = result["job"]
|
||||||
|
q = self.job_queues.get(job)
|
||||||
|
if q:
|
||||||
|
q.put(result)
|
||||||
|
if result.get("eos"):
|
||||||
|
self.job_queues.pop(job, None)
|
||||||
|
if not self.job_queues:
|
||||||
|
self.has_jobs.clear()
|
||||||
|
|
||||||
|
def submit(self, job) -> queue.Queue:
|
||||||
|
q = queue.Queue()
|
||||||
|
with self.lock:
|
||||||
|
self.job_queues[job] = q
|
||||||
|
self.generator.enqueue(job)
|
||||||
|
self.has_jobs.set()
|
||||||
|
return q
|
||||||
|
|
||||||
|
def cancel(self, job):
|
||||||
|
with self.lock:
|
||||||
|
if job in self.job_queues:
|
||||||
|
self.generator.cancel(job)
|
||||||
|
self.job_queues[job].put(None)
|
||||||
|
del self.job_queues[job]
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.active = False
|
||||||
|
self.has_jobs.set()
|
||||||
|
self.thread.join(timeout=5)
|
||||||
|
|
||||||
|
|
||||||
class Exllamav3Model:
|
class Exllamav3Model:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return torch.device(0)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, path_to_model):
|
def from_pretrained(cls, path_to_model):
|
||||||
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
||||||
|
|
@ -58,7 +148,7 @@ class Exllamav3Model:
|
||||||
logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
|
logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
|
||||||
max_tokens = adjusted_tokens
|
max_tokens = adjusted_tokens
|
||||||
|
|
||||||
# Parse cache type (ExLlamaV2 pattern)
|
# Parse cache type
|
||||||
cache_type = shared.args.cache_type.lower()
|
cache_type = shared.args.cache_type.lower()
|
||||||
cache_kwargs = {}
|
cache_kwargs = {}
|
||||||
if cache_type == 'fp16':
|
if cache_type == 'fp16':
|
||||||
|
|
@ -97,8 +187,21 @@ class Exllamav3Model:
|
||||||
load_params['tensor_p'] = True
|
load_params['tensor_p'] = True
|
||||||
load_params['tp_backend'] = shared.args.tp_backend
|
load_params['tp_backend'] = shared.args.tp_backend
|
||||||
|
|
||||||
model.load(**load_params)
|
# Load vision and draft before the main model so autosplit
|
||||||
tokenizer = Tokenizer.from_config(config)
|
# accounts for their VRAM usage.
|
||||||
|
|
||||||
|
# Load vision model component (ExLlamaV3 native)
|
||||||
|
vision_model = None
|
||||||
|
if "vision_config" in config.config_dict:
|
||||||
|
logger.info("Vision component detected in model config. Attempting to load...")
|
||||||
|
try:
|
||||||
|
vision_model = Model.from_config(config, component="vision")
|
||||||
|
vision_model.load(progressbar=True)
|
||||||
|
logger.info("Vision model loaded successfully.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
|
||||||
|
else:
|
||||||
|
logger.info("No vision component in model config. Skipping multimodal setup.")
|
||||||
|
|
||||||
# Initialize draft model for speculative decoding
|
# Initialize draft model for speculative decoding
|
||||||
draft_model = None
|
draft_model = None
|
||||||
|
|
@ -114,23 +217,8 @@ class Exllamav3Model:
|
||||||
logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
|
logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
|
||||||
else:
|
else:
|
||||||
draft_config = Config.from_directory(str(draft_path))
|
draft_config = Config.from_directory(str(draft_path))
|
||||||
|
|
||||||
# Set context size for draft model with 256-multiple validation
|
|
||||||
if shared.args.ctx_size_draft > 0:
|
|
||||||
draft_max_tokens = shared.args.ctx_size_draft
|
|
||||||
else:
|
|
||||||
draft_max_tokens = shared.args.ctx_size
|
|
||||||
|
|
||||||
# Validate draft model context size is a multiple of 256
|
|
||||||
if draft_max_tokens % 256 != 0:
|
|
||||||
adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
|
|
||||||
logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
|
|
||||||
draft_max_tokens = adjusted_draft_tokens
|
|
||||||
|
|
||||||
draft_config.max_seq_len = draft_max_tokens
|
|
||||||
|
|
||||||
draft_model = Model.from_config(draft_config)
|
draft_model = Model.from_config(draft_config)
|
||||||
draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
|
draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
|
||||||
|
|
||||||
draft_load_params = {'progressbar': True}
|
draft_load_params = {'progressbar': True}
|
||||||
if split:
|
if split:
|
||||||
|
|
@ -139,18 +227,9 @@ class Exllamav3Model:
|
||||||
draft_model.load(**draft_load_params)
|
draft_model.load(**draft_load_params)
|
||||||
logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
|
logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
|
||||||
|
|
||||||
# Load vision model component (ExLlamaV3 native)
|
# Load main model last
|
||||||
vision_model = None
|
model.load(**load_params)
|
||||||
if "vision_config" in config.config_dict:
|
tokenizer = Tokenizer.from_config(config)
|
||||||
logger.info("Vision component detected in model config. Attempting to load...")
|
|
||||||
try:
|
|
||||||
vision_model = Model.from_config(config, component="vision")
|
|
||||||
vision_model.load(progressbar=True)
|
|
||||||
logger.info("Vision model loaded successfully.")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
|
|
||||||
else:
|
|
||||||
logger.info("No vision component in model config. Skipping multimodal setup.")
|
|
||||||
|
|
||||||
generator = Generator(
|
generator = Generator(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -158,7 +237,7 @@ class Exllamav3Model:
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
draft_model=draft_model,
|
draft_model=draft_model,
|
||||||
draft_cache=draft_cache,
|
draft_cache=draft_cache,
|
||||||
num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0,
|
num_draft_tokens=shared.args.draft_max if draft_model is not None else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = cls()
|
result = cls()
|
||||||
|
|
@ -166,6 +245,7 @@ class Exllamav3Model:
|
||||||
result.cache = cache
|
result.cache = cache
|
||||||
result.tokenizer = tokenizer
|
result.tokenizer = tokenizer
|
||||||
result.generator = generator
|
result.generator = generator
|
||||||
|
result.parallel_generator = ConcurrentGenerator(generator)
|
||||||
result.config = config
|
result.config = config
|
||||||
result.max_tokens = max_tokens
|
result.max_tokens = max_tokens
|
||||||
result.vision_model = vision_model
|
result.vision_model = vision_model
|
||||||
|
|
@ -286,11 +366,16 @@ class Exllamav3Model:
|
||||||
|
|
||||||
# 3. Get the priority list and handle temperature_last
|
# 3. Get the priority list and handle temperature_last
|
||||||
default_priority = ['repetition_penalty', 'presence_frequency_penalty', 'top_k', 'top_p', 'min_p', 'temperature']
|
default_priority = ['repetition_penalty', 'presence_frequency_penalty', 'top_k', 'top_p', 'min_p', 'temperature']
|
||||||
sampler_priority = state.get('sampler_priority') or default_priority
|
sampler_priority = list(state.get('sampler_priority') or default_priority)
|
||||||
|
|
||||||
if state['temperature_last'] and 'temperature' in sampler_priority:
|
if state['temperature_last'] and 'temperature' in sampler_priority:
|
||||||
sampler_priority.append(sampler_priority.pop(sampler_priority.index('temperature')))
|
sampler_priority.append(sampler_priority.pop(sampler_priority.index('temperature')))
|
||||||
|
|
||||||
|
# The preset system uses separate 'presence_penalty' and
|
||||||
|
# 'frequency_penalty', but ExLlamaV3 has a single combined
|
||||||
|
# SS_PresFreqP sampler. Normalize to the combined name.
|
||||||
|
sampler_priority = ['presence_frequency_penalty' if x in ('presence_penalty', 'frequency_penalty') else x for x in sampler_priority]
|
||||||
|
|
||||||
# 4. Sort the unordered list based on the priority list
|
# 4. Sort the unordered list based on the priority list
|
||||||
def custom_sort_key(sampler_obj):
|
def custom_sort_key(sampler_obj):
|
||||||
class_name = sampler_obj.__class__.__name__
|
class_name = sampler_obj.__class__.__name__
|
||||||
|
|
@ -302,7 +387,11 @@ class Exllamav3Model:
|
||||||
ordered_samplers = sorted(unordered_samplers, key=custom_sort_key)
|
ordered_samplers = sorted(unordered_samplers, key=custom_sort_key)
|
||||||
|
|
||||||
# 5. Add the final sampling stage and build the sampler
|
# 5. Add the final sampling stage and build the sampler
|
||||||
|
if state.get('adaptive_target', 0) > 0:
|
||||||
|
ordered_samplers.append(SS_AdaptiveP(state['adaptive_target'], state['adaptive_decay']))
|
||||||
|
else:
|
||||||
ordered_samplers.append(SS_Sample())
|
ordered_samplers.append(SS_Sample())
|
||||||
|
|
||||||
sampler = CustomSampler(ordered_samplers)
|
sampler = CustomSampler(ordered_samplers)
|
||||||
|
|
||||||
# Encode prompt with embeddings (ExLlamaV3-specific)
|
# Encode prompt with embeddings (ExLlamaV3-specific)
|
||||||
|
|
@ -323,43 +412,86 @@ class Exllamav3Model:
|
||||||
else:
|
else:
|
||||||
max_new_tokens = state['max_new_tokens']
|
max_new_tokens = state['max_new_tokens']
|
||||||
|
|
||||||
# Get stop conditions
|
# Use full EOS token list from config (may contain multiple IDs)
|
||||||
stop_conditions = []
|
stop_conditions = []
|
||||||
if not state['ban_eos_token']:
|
if not state['ban_eos_token']:
|
||||||
if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
|
for eos_id in self.config.eos_token_id_list:
|
||||||
stop_conditions.append(self.tokenizer.eos_token_id)
|
if eos_id is not None:
|
||||||
|
stop_conditions.append(eos_id)
|
||||||
|
|
||||||
|
# Build filters for logit_bias (OpenAI API)
|
||||||
|
filters = []
|
||||||
|
logit_bias = state.get('logit_bias')
|
||||||
|
if logit_bias:
|
||||||
|
filters.append(LogitBiasFilter(self.tokenizer, logit_bias))
|
||||||
|
|
||||||
|
# Logprobs support (OpenAI API)
|
||||||
|
logprobs = state.get('logprobs', 0) or 0
|
||||||
|
return_top_tokens = logprobs if logprobs > 0 else 0
|
||||||
|
|
||||||
|
seed = state.get('seed', -1)
|
||||||
job = Job(
|
job = Job(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
decode_special_tokens=not state['skip_special_tokens'],
|
decode_special_tokens=not state['skip_special_tokens'],
|
||||||
embeddings=image_embeddings if image_embeddings else None,
|
embeddings=image_embeddings if image_embeddings else None,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
|
seed=seed if seed >= 0 else None,
|
||||||
stop_conditions=stop_conditions if stop_conditions else None,
|
stop_conditions=stop_conditions if stop_conditions else None,
|
||||||
|
filters=filters if filters else None,
|
||||||
|
return_top_tokens=return_top_tokens,
|
||||||
|
return_probs=return_top_tokens > 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stream generation
|
# Stream generation
|
||||||
self.generator.enqueue(job)
|
|
||||||
|
|
||||||
response_text = ""
|
response_text = ""
|
||||||
|
stop_event = state.get('stop_event')
|
||||||
|
self.last_completion_probabilities = []
|
||||||
|
|
||||||
|
result_queue = self.parallel_generator.submit(job)
|
||||||
try:
|
try:
|
||||||
while self.generator.num_remaining_jobs():
|
while True:
|
||||||
if shared.stop_everything:
|
if shared.stop_everything or (stop_event and stop_event.is_set()):
|
||||||
break
|
break
|
||||||
|
try:
|
||||||
results = self.generator.iterate()
|
result = result_queue.get(timeout=0.1)
|
||||||
for result in results:
|
except queue.Empty:
|
||||||
if "eos" in result and result["eos"]:
|
continue
|
||||||
|
if result is None or result.get("eos"):
|
||||||
|
# Capture logprobs from the final eos result too
|
||||||
|
if result is not None and return_top_tokens > 0:
|
||||||
|
self._capture_logprobs(result)
|
||||||
break
|
break
|
||||||
|
|
||||||
chunk = result.get("text", "")
|
chunk = result.get("text", "")
|
||||||
|
|
||||||
|
# Capture logprobs from streaming results
|
||||||
|
if return_top_tokens > 0:
|
||||||
|
self._capture_logprobs(result)
|
||||||
|
|
||||||
if chunk:
|
if chunk:
|
||||||
response_text += chunk
|
response_text += chunk
|
||||||
yield response_text
|
yield response_text
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.generator.clear_queue()
|
self.parallel_generator.cancel(job)
|
||||||
|
|
||||||
|
def _capture_logprobs(self, result):
|
||||||
|
"""Convert ExLlamav3 top-k token data to the shared logprobs format."""
|
||||||
|
top_k_tokens = result.get("top_k_tokens")
|
||||||
|
top_k_probs = result.get("top_k_probs")
|
||||||
|
if top_k_tokens is None or top_k_probs is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
id_to_piece = self.tokenizer.get_id_to_piece_list(True)
|
||||||
|
# top_k_tokens shape: (batch, seq_len, k), top_k_probs same
|
||||||
|
for seq_idx in range(top_k_tokens.shape[1]):
|
||||||
|
entry = {"top_logprobs": []}
|
||||||
|
for k_idx in range(top_k_tokens.shape[2]):
|
||||||
|
token_id = top_k_tokens[0, seq_idx, k_idx].item()
|
||||||
|
prob = top_k_probs[0, seq_idx, k_idx].item()
|
||||||
|
token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
|
||||||
|
logprob = math.log(prob) if prob > 0 else float("-inf")
|
||||||
|
entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
|
||||||
|
self.last_completion_probabilities.append(entry)
|
||||||
|
|
||||||
def generate(self, prompt, state):
|
def generate(self, prompt, state):
|
||||||
output = ""
|
output = ""
|
||||||
|
|
@ -422,6 +554,13 @@ class Exllamav3Model:
|
||||||
def unload(self):
|
def unload(self):
|
||||||
logger.info("Unloading ExLlamaV3 model components...")
|
logger.info("Unloading ExLlamaV3 model components...")
|
||||||
|
|
||||||
|
if hasattr(self, 'parallel_generator') and self.parallel_generator is not None:
|
||||||
|
try:
|
||||||
|
self.parallel_generator.stop()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error stopping parallel generator: {e}")
|
||||||
|
self.parallel_generator = None
|
||||||
|
|
||||||
if hasattr(self, 'vision_model') and self.vision_model is not None:
|
if hasattr(self, 'vision_model') and self.vision_model is not None:
|
||||||
try:
|
try:
|
||||||
del self.vision_model
|
del self.vision_model
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,12 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
||||||
self.ex_model.load(**load_params)
|
self.ex_model.load(**load_params)
|
||||||
self.past_seq = None
|
self.past_seq = None
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
self.layer_type = layer_type
|
||||||
|
self.cache_kwargs = cache_kwargs
|
||||||
|
|
||||||
|
if shared.args.cfg_cache:
|
||||||
|
self.ex_cache_negative = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
|
||||||
|
self.past_seq_negative = None
|
||||||
|
|
||||||
def _validate_model_class(self):
|
def _validate_model_class(self):
|
||||||
pass
|
pass
|
||||||
|
|
@ -126,7 +132,7 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
||||||
reset = True
|
reset = True
|
||||||
|
|
||||||
# Maximum number of tokens to process in a single forward pass
|
# Maximum number of tokens to process in a single forward pass
|
||||||
max_chunk_size = 256
|
max_chunk_size = 2048
|
||||||
|
|
||||||
# Make the forward call
|
# Make the forward call
|
||||||
if labels is None:
|
if labels is None:
|
||||||
|
|
@ -147,17 +153,16 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
||||||
# Process tokens from longest_prefix to second-to-last token
|
# Process tokens from longest_prefix to second-to-last token
|
||||||
tokens_to_process = seq_tensor[longest_prefix:-1]
|
tokens_to_process = seq_tensor[longest_prefix:-1]
|
||||||
|
|
||||||
# Process in chunks if the number of tokens is large
|
# Use prefill() to fill the cache without computing logits
|
||||||
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
|
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
|
||||||
chunk = tokens_to_process[i:i + max_chunk_size]
|
chunk = tokens_to_process[i:i + max_chunk_size]
|
||||||
self.ex_model.forward(
|
self.ex_model.prefill(
|
||||||
input_ids=chunk.view(1, -1),
|
input_ids=chunk.view(1, -1),
|
||||||
params={
|
params={
|
||||||
"attn_mode": "flash_attn",
|
"attn_mode": "flash_attn",
|
||||||
"cache": ex_cache,
|
"cache": ex_cache,
|
||||||
"past_len": longest_prefix + i,
|
"past_len": longest_prefix + i,
|
||||||
"batch_shape": (1, self.max_tokens),
|
"batch_shape": (1, self.max_tokens),
|
||||||
"reconstruct": False # Force memory-efficient path
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -168,18 +173,17 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
||||||
# Process all tokens except the last one
|
# Process all tokens except the last one
|
||||||
tokens_to_process = seq_tensor[:-1]
|
tokens_to_process = seq_tensor[:-1]
|
||||||
|
|
||||||
# Process in chunks if the number of tokens is large
|
# Use prefill() to fill the cache without computing logits
|
||||||
current_len = 0
|
current_len = 0
|
||||||
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
|
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
|
||||||
chunk = tokens_to_process[i:i + max_chunk_size]
|
chunk = tokens_to_process[i:i + max_chunk_size]
|
||||||
self.ex_model.forward(
|
self.ex_model.prefill(
|
||||||
input_ids=chunk.view(1, -1),
|
input_ids=chunk.view(1, -1),
|
||||||
params={
|
params={
|
||||||
"attn_mode": "flash_attn",
|
"attn_mode": "flash_attn",
|
||||||
"cache": ex_cache,
|
"cache": ex_cache,
|
||||||
"past_len": current_len,
|
"past_len": current_len,
|
||||||
"batch_shape": (1, self.max_tokens),
|
"batch_shape": (1, self.max_tokens),
|
||||||
"reconstruct": False # Force memory-efficient path
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
current_len += chunk.shape[0]
|
current_len += chunk.shape[0]
|
||||||
|
|
@ -194,24 +198,26 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
||||||
"cache": ex_cache,
|
"cache": ex_cache,
|
||||||
"past_len": current_len,
|
"past_len": current_len,
|
||||||
"batch_shape": (1, self.max_tokens),
|
"batch_shape": (1, self.max_tokens),
|
||||||
"reconstruct": False # Force memory-efficient path
|
|
||||||
}
|
}
|
||||||
).to(input_ids.device).float()
|
).to(input_ids.device).float()
|
||||||
else:
|
else:
|
||||||
# When processing with labels, handle as a complete sequence
|
# Labels path: use cache for cross-chunk attention.
|
||||||
# Process in chunks if the number of tokens is large
|
|
||||||
tokens_to_process = seq_tensor
|
tokens_to_process = seq_tensor
|
||||||
all_logits = None
|
all_logits = None
|
||||||
|
current_len = 0
|
||||||
|
|
||||||
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
|
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
|
||||||
chunk = tokens_to_process[i:i + max_chunk_size]
|
chunk = tokens_to_process[i:i + max_chunk_size]
|
||||||
chunk_logits = self.ex_model.forward(
|
chunk_logits = self.ex_model.forward(
|
||||||
input_ids=chunk.view(1, -1),
|
input_ids=chunk.view(1, -1),
|
||||||
params={
|
params={
|
||||||
"attn_mode": "flash_attn_nc", # No caching for training
|
"attn_mode": "flash_attn",
|
||||||
"reconstruct": False # Force memory-efficient path
|
"cache": ex_cache,
|
||||||
|
"past_len": current_len,
|
||||||
|
"batch_shape": (1, self.max_tokens),
|
||||||
}
|
}
|
||||||
).float()
|
).float()
|
||||||
|
current_len += chunk.shape[0]
|
||||||
|
|
||||||
if all_logits is None:
|
if all_logits is None:
|
||||||
all_logits = chunk_logits
|
all_logits = chunk_logits
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
import importlib
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
@ -38,9 +38,15 @@ def load_extensions():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Prefer user extension, fall back to system extension
|
# Prefer user extension, fall back to system extension
|
||||||
user_script_path = Path(f'user_data/extensions/{name}/script.py')
|
user_script_path = shared.user_data_dir / 'extensions' / name / 'script.py'
|
||||||
if user_script_path.exists():
|
if user_script_path.exists():
|
||||||
extension = importlib.import_module(f"user_data.extensions.{name}.script")
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
f"user_ext_{name}",
|
||||||
|
str(user_script_path)
|
||||||
|
)
|
||||||
|
extension = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[spec.name] = extension
|
||||||
|
spec.loader.exec_module(extension)
|
||||||
else:
|
else:
|
||||||
extension = importlib.import_module(f"extensions.{name}.script")
|
extension = importlib.import_module(f"extensions.{name}.script")
|
||||||
|
|
||||||
|
|
@ -53,7 +59,7 @@ def load_extensions():
|
||||||
state[name] = [True, i, extension] # Store extension object
|
state[name] = [True, i, extension] # Store extension object
|
||||||
|
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
extension_location = Path('user_data/extensions') / name if user_script_path.exists() else Path('extensions') / name
|
extension_location = shared.user_data_dir / 'extensions' / name if user_script_path.exists() else Path('extensions') / name
|
||||||
windows_path = str(extension_location).replace('/', '\\')
|
windows_path = str(extension_location).replace('/', '\\')
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n"
|
f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n"
|
||||||
|
|
@ -206,6 +212,7 @@ def _apply_custom_js():
|
||||||
|
|
||||||
|
|
||||||
def create_extensions_block():
|
def create_extensions_block():
|
||||||
|
import gradio as gr
|
||||||
to_display = []
|
to_display = []
|
||||||
for extension, name in iterator():
|
for extension, name in iterator():
|
||||||
if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
|
if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
|
||||||
|
|
@ -220,6 +227,7 @@ def create_extensions_block():
|
||||||
|
|
||||||
|
|
||||||
def create_extensions_tabs():
|
def create_extensions_tabs():
|
||||||
|
import gradio as gr
|
||||||
for extension, name in iterator():
|
for extension, name in iterator():
|
||||||
if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
|
if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
|
||||||
display_name = getattr(extension, 'params', {}).get('display_name', name)
|
display_name = getattr(extension, 'params', {}).get('display_name', name)
|
||||||
|
|
|
||||||
|
|
@ -1,97 +0,0 @@
|
||||||
'''
|
|
||||||
Most of the code here was adapted from:
|
|
||||||
https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184
|
|
||||||
'''
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
import warnings
|
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
import gradio.routes
|
|
||||||
import gradio.utils
|
|
||||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
|
||||||
|
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
orig_create_app = gradio.routes.App.create_app
|
|
||||||
|
|
||||||
|
|
||||||
# Be strict about only approving access to localhost by default
|
|
||||||
def create_app_with_trustedhost(*args, **kwargs):
|
|
||||||
app = orig_create_app(*args, **kwargs)
|
|
||||||
|
|
||||||
if not (shared.args.listen or shared.args.share):
|
|
||||||
app.add_middleware(
|
|
||||||
TrustedHostMiddleware,
|
|
||||||
allowed_hosts=["localhost", "127.0.0.1"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
gradio.routes.App.create_app = create_app_with_trustedhost
|
|
||||||
gradio.utils.launch_counter = lambda: None
|
|
||||||
|
|
||||||
|
|
||||||
class GradioDeprecationWarning(DeprecationWarning):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def repair(grclass):
|
|
||||||
if not getattr(grclass, 'EVENTS', None):
|
|
||||||
return
|
|
||||||
|
|
||||||
@wraps(grclass.__init__)
|
|
||||||
def __repaired_init__(self, *args, tooltip=None, source=None, original=grclass.__init__, **kwargs):
|
|
||||||
if source:
|
|
||||||
kwargs["sources"] = [source]
|
|
||||||
|
|
||||||
allowed_kwargs = inspect.signature(original).parameters
|
|
||||||
fixed_kwargs = {}
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if k in allowed_kwargs:
|
|
||||||
fixed_kwargs[k] = v
|
|
||||||
else:
|
|
||||||
warnings.warn(f"unexpected argument for {grclass.__name__}: {k}", GradioDeprecationWarning, stacklevel=2)
|
|
||||||
|
|
||||||
original(self, *args, **fixed_kwargs)
|
|
||||||
|
|
||||||
self.webui_tooltip = tooltip
|
|
||||||
|
|
||||||
for event in self.EVENTS:
|
|
||||||
replaced_event = getattr(self, str(event))
|
|
||||||
|
|
||||||
def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs):
|
|
||||||
if _js:
|
|
||||||
xkwargs['js'] = _js
|
|
||||||
|
|
||||||
return replaced_event(*xargs, **xkwargs)
|
|
||||||
|
|
||||||
setattr(self, str(event), fun)
|
|
||||||
|
|
||||||
grclass.__init__ = __repaired_init__
|
|
||||||
grclass.update = gr.update
|
|
||||||
|
|
||||||
|
|
||||||
for component in set(gr.components.__all__ + gr.layouts.__all__):
|
|
||||||
repair(getattr(gr, component, None))
|
|
||||||
|
|
||||||
|
|
||||||
class Dependency(gr.events.Dependency):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def then(*xargs, _js=None, **xkwargs):
|
|
||||||
if _js:
|
|
||||||
xkwargs['js'] = _js
|
|
||||||
|
|
||||||
return original_then(*xargs, **xkwargs)
|
|
||||||
|
|
||||||
original_then = self.then
|
|
||||||
self.then = then
|
|
||||||
|
|
||||||
|
|
||||||
gr.events.Dependency = Dependency
|
|
||||||
|
|
||||||
gr.Box = gr.Group
|
|
||||||
|
|
@ -10,6 +10,7 @@ import markdown
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
from modules.reasoning import extract_reasoning
|
||||||
from modules.sane_markdown_lists import SaneListExtension
|
from modules.sane_markdown_lists import SaneListExtension
|
||||||
from modules.utils import get_available_chat_styles
|
from modules.utils import get_available_chat_styles
|
||||||
|
|
||||||
|
|
@ -109,94 +110,40 @@ def replace_blockquote(m):
|
||||||
|
|
||||||
|
|
||||||
def extract_thinking_block(string):
|
def extract_thinking_block(string):
|
||||||
"""Extract thinking blocks from the beginning of a string."""
|
"""Extract thinking blocks from the beginning of an HTML-escaped string."""
|
||||||
if not string:
|
return extract_reasoning(string, html_escaped=True)
|
||||||
return None, string
|
|
||||||
|
|
||||||
THINK_START_TAG = "<think>"
|
|
||||||
THINK_END_TAG = "</think>"
|
|
||||||
|
|
||||||
# Look for think tag first
|
|
||||||
start_pos = string.find(THINK_START_TAG)
|
|
||||||
end_pos = string.find(THINK_END_TAG)
|
|
||||||
|
|
||||||
# If think tags found, use existing logic
|
|
||||||
if start_pos != -1 or end_pos != -1:
|
|
||||||
# handle missing start or end tags
|
|
||||||
if start_pos == -1:
|
|
||||||
thought_start = 0
|
|
||||||
else:
|
|
||||||
thought_start = start_pos + len(THINK_START_TAG)
|
|
||||||
if end_pos == -1:
|
|
||||||
thought_end = len(string)
|
|
||||||
content_start = len(string)
|
|
||||||
else:
|
|
||||||
thought_end = end_pos
|
|
||||||
content_start = end_pos + len(THINK_END_TAG)
|
|
||||||
thinking_content = string[thought_start:thought_end]
|
|
||||||
remaining_content = string[content_start:]
|
|
||||||
return thinking_content, remaining_content
|
|
||||||
|
|
||||||
# If think tags not found, try GPT-OSS alternative format
|
|
||||||
ALT_START = "<|channel|>analysis<|message|>"
|
|
||||||
ALT_END = "<|end|>"
|
|
||||||
ALT_CONTENT_START = "<|start|>assistant<|channel|>final<|message|>"
|
|
||||||
|
|
||||||
alt_start_pos = string.find(ALT_START)
|
|
||||||
alt_end_pos = string.find(ALT_END)
|
|
||||||
alt_content_pos = string.find(ALT_CONTENT_START)
|
|
||||||
|
|
||||||
if alt_start_pos != -1 or alt_end_pos != -1:
|
|
||||||
if alt_start_pos == -1:
|
|
||||||
thought_start = 0
|
|
||||||
else:
|
|
||||||
thought_start = alt_start_pos + len(ALT_START)
|
|
||||||
|
|
||||||
# If no explicit end tag but content start exists, use content start as end
|
|
||||||
if alt_end_pos == -1:
|
|
||||||
if alt_content_pos != -1:
|
|
||||||
thought_end = alt_content_pos
|
|
||||||
content_start = alt_content_pos + len(ALT_CONTENT_START)
|
|
||||||
else:
|
|
||||||
thought_end = len(string)
|
|
||||||
content_start = len(string)
|
|
||||||
else:
|
|
||||||
thought_end = alt_end_pos
|
|
||||||
content_start = alt_content_pos + len(ALT_CONTENT_START) if alt_content_pos != -1 else alt_end_pos + len(ALT_END)
|
|
||||||
|
|
||||||
thinking_content = string[thought_start:thought_end]
|
|
||||||
remaining_content = string[content_start:]
|
|
||||||
return thinking_content, remaining_content
|
|
||||||
|
|
||||||
# Try seed:think format
|
|
||||||
SEED_START = "<seed:think>"
|
|
||||||
SEED_END = "</seed:think>"
|
|
||||||
|
|
||||||
seed_start_pos = string.find(SEED_START)
|
|
||||||
seed_end_pos = string.find(SEED_END)
|
|
||||||
|
|
||||||
if seed_start_pos != -1 or seed_end_pos != -1:
|
|
||||||
if seed_start_pos == -1:
|
|
||||||
thought_start = 0
|
|
||||||
else:
|
|
||||||
thought_start = seed_start_pos + len(SEED_START)
|
|
||||||
|
|
||||||
if seed_end_pos == -1:
|
|
||||||
thought_end = len(string)
|
|
||||||
content_start = len(string)
|
|
||||||
else:
|
|
||||||
thought_end = seed_end_pos
|
|
||||||
content_start = seed_end_pos + len(SEED_END)
|
|
||||||
|
|
||||||
thinking_content = string[thought_start:thought_end]
|
|
||||||
remaining_content = string[content_start:]
|
|
||||||
return thinking_content, remaining_content
|
|
||||||
|
|
||||||
# Return if no format is found
|
|
||||||
return None, string
|
|
||||||
|
|
||||||
|
|
||||||
def build_thinking_block(thinking_content, message_id, has_remaining_content):
|
|
||||||
|
def build_tool_call_block(header, body, message_id, index):
|
||||||
|
"""Build HTML for a tool call accordion block."""
|
||||||
|
block_id = f"tool-call-{message_id}-{index}"
|
||||||
|
|
||||||
|
if body == '...':
|
||||||
|
# Pending placeholder — no expandable body, just title with ellipsis
|
||||||
|
return f'''
|
||||||
|
<details class="thinking-block" data-block-id="{block_id}">
|
||||||
|
<summary class="thinking-header">
|
||||||
|
{tool_svg_small}
|
||||||
|
<span class="thinking-title">{html.escape(header)} ...</span>
|
||||||
|
</summary>
|
||||||
|
</details>
|
||||||
|
'''
|
||||||
|
|
||||||
|
# Build a plain <pre> directly to avoid highlight.js auto-detection
|
||||||
|
escaped_body = html.escape(body)
|
||||||
|
return f'''
|
||||||
|
<details class="thinking-block" data-block-id="{block_id}">
|
||||||
|
<summary class="thinking-header">
|
||||||
|
{tool_svg_small}
|
||||||
|
<span class="thinking-title">{html.escape(header)}</span>
|
||||||
|
</summary>
|
||||||
|
<div class="thinking-content pretty_scrollbar"><pre><code class="nohighlight">{escaped_body}</code></pre></div>
|
||||||
|
</details>
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
def build_thinking_block(thinking_content, message_id, has_remaining_content, thinking_index=0):
|
||||||
"""Build HTML for a thinking block."""
|
"""Build HTML for a thinking block."""
|
||||||
if thinking_content is None:
|
if thinking_content is None:
|
||||||
return None
|
return None
|
||||||
|
|
@ -205,7 +152,7 @@ def build_thinking_block(thinking_content, message_id, has_remaining_content):
|
||||||
thinking_html = process_markdown_content(thinking_content)
|
thinking_html = process_markdown_content(thinking_content)
|
||||||
|
|
||||||
# Generate unique ID for the thinking block
|
# Generate unique ID for the thinking block
|
||||||
block_id = f"thinking-{message_id}-0"
|
block_id = f"thinking-{message_id}-{thinking_index}"
|
||||||
|
|
||||||
# Check if thinking is complete or still in progress
|
# Check if thinking is complete or still in progress
|
||||||
is_streaming = not has_remaining_content
|
is_streaming = not has_remaining_content
|
||||||
|
|
@ -238,23 +185,27 @@ def process_markdown_content(string):
|
||||||
if not string:
|
if not string:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Define a unique placeholder for LaTeX asterisks
|
# Define unique placeholders for LaTeX asterisks and underscores
|
||||||
LATEX_ASTERISK_PLACEHOLDER = "LATEXASTERISKPLACEHOLDER"
|
LATEX_ASTERISK_PLACEHOLDER = "LATEXASTERISKPLACEHOLDER"
|
||||||
|
LATEX_UNDERSCORE_PLACEHOLDER = "LATEXUNDERSCOREPLACEHOLDER"
|
||||||
|
|
||||||
def protect_asterisks_in_latex(match):
|
def protect_asterisks_underscores_in_latex(match):
|
||||||
"""A replacer function for re.sub to protect asterisks in multiple LaTeX formats."""
|
"""A replacer function for re.sub to protect asterisks and underscores in multiple LaTeX formats."""
|
||||||
# Check which delimiter group was captured
|
# Check which delimiter group was captured
|
||||||
if match.group(1) is not None: # Content from $$...$$
|
if match.group(1) is not None: # Content from $$...$$
|
||||||
content = match.group(1)
|
content = match.group(1)
|
||||||
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
||||||
return f'$${modified_content}$$'
|
modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER)
|
||||||
|
return f'{modified_content}'
|
||||||
elif match.group(2) is not None: # Content from \[...\]
|
elif match.group(2) is not None: # Content from \[...\]
|
||||||
content = match.group(2)
|
content = match.group(2)
|
||||||
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
||||||
|
modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER)
|
||||||
return f'\\[{modified_content}\\]'
|
return f'\\[{modified_content}\\]'
|
||||||
elif match.group(3) is not None: # Content from \(...\)
|
elif match.group(3) is not None: # Content from \(...\)
|
||||||
content = match.group(3)
|
content = match.group(3)
|
||||||
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER)
|
||||||
|
modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER)
|
||||||
return f'\\({modified_content}\\)'
|
return f'\\({modified_content}\\)'
|
||||||
|
|
||||||
return match.group(0) # Fallback
|
return match.group(0) # Fallback
|
||||||
|
|
@ -288,9 +239,10 @@ def process_markdown_content(string):
|
||||||
string = string.replace('\\end{equation*}', '$$')
|
string = string.replace('\\end{equation*}', '$$')
|
||||||
string = re.sub(r"(.)```", r"\1\n```", string)
|
string = re.sub(r"(.)```", r"\1\n```", string)
|
||||||
|
|
||||||
# Protect asterisks within all LaTeX blocks before markdown conversion
|
# Protect asterisks and underscores within all LaTeX blocks before markdown conversion
|
||||||
latex_pattern = re.compile(r'\$\$(.*?)\$\$|\\\[(.*?)\\\]|\\\((.*?)\\\)', re.DOTALL)
|
latex_pattern = re.compile(r'((?:^|[\r\n\s])\$\$[^`]*?\$\$)|\\\[(.*?)\\\]|\\\((.*?)\\\)',
|
||||||
string = latex_pattern.sub(protect_asterisks_in_latex, string)
|
re.DOTALL)
|
||||||
|
string = latex_pattern.sub(protect_asterisks_underscores_in_latex, string)
|
||||||
|
|
||||||
result = ''
|
result = ''
|
||||||
is_code = False
|
is_code = False
|
||||||
|
|
@ -302,11 +254,11 @@ def process_markdown_content(string):
|
||||||
|
|
||||||
if stripped_line.startswith('```'):
|
if stripped_line.startswith('```'):
|
||||||
is_code = not is_code
|
is_code = not is_code
|
||||||
elif stripped_line.startswith('$$'):
|
elif stripped_line.startswith('$$') and (stripped_line == "$$" or not stripped_line.endswith('$$')):
|
||||||
is_latex = not is_latex
|
is_latex = not is_latex
|
||||||
elif stripped_line.endswith('$$'):
|
elif stripped_line.endswith('$$'):
|
||||||
is_latex = False
|
is_latex = False
|
||||||
elif stripped_line.startswith('\\\\['):
|
elif stripped_line.startswith('\\\\[') and not stripped_line.endswith('\\\\]'):
|
||||||
is_latex = True
|
is_latex = True
|
||||||
elif stripped_line.startswith('\\\\]'):
|
elif stripped_line.startswith('\\\\]'):
|
||||||
is_latex = False
|
is_latex = False
|
||||||
|
|
@ -351,8 +303,9 @@ def process_markdown_content(string):
|
||||||
# Convert to HTML using markdown
|
# Convert to HTML using markdown
|
||||||
html_output = markdown.markdown(result, extensions=['fenced_code', 'tables', SaneListExtension()])
|
html_output = markdown.markdown(result, extensions=['fenced_code', 'tables', SaneListExtension()])
|
||||||
|
|
||||||
# Restore the LaTeX asterisks after markdown conversion
|
# Restore the LaTeX asterisks and underscores after markdown conversion
|
||||||
html_output = html_output.replace(LATEX_ASTERISK_PLACEHOLDER, '*')
|
html_output = html_output.replace(LATEX_ASTERISK_PLACEHOLDER, '*')
|
||||||
|
html_output = html_output.replace(LATEX_UNDERSCORE_PLACEHOLDER, '_')
|
||||||
|
|
||||||
# Remove extra newlines before </code>
|
# Remove extra newlines before </code>
|
||||||
html_output = re.sub(r'\s*</code>', '</code>', html_output)
|
html_output = re.sub(r'\s*</code>', '</code>', html_output)
|
||||||
|
|
@ -364,6 +317,9 @@ def process_markdown_content(string):
|
||||||
# Unescape backslashes
|
# Unescape backslashes
|
||||||
html_output = html_output.replace('\\\\', '\\')
|
html_output = html_output.replace('\\\\', '\\')
|
||||||
|
|
||||||
|
# Wrap tables in a scrollable div
|
||||||
|
html_output = html_output.replace('<table>', '<div class="table-wrapper pretty_scrollbar"><table>').replace('</table>', '</table></div>')
|
||||||
|
|
||||||
return html_output
|
return html_output
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -380,25 +336,67 @@ def convert_to_markdown(string, message_id=None):
|
||||||
if message_id is None:
|
if message_id is None:
|
||||||
message_id = "unknown"
|
message_id = "unknown"
|
||||||
|
|
||||||
# Extract different components from the string
|
# Find tool call blocks by position, then process the text segments
|
||||||
|
# between them using extract_thinking_block (which supports all
|
||||||
|
# THINKING_FORMATS, including end-only variants like Qwen's).
|
||||||
|
tool_call_pattern = re.compile(r'<tool_call>(.*?)\n(.*?)\n</tool_call>', re.DOTALL)
|
||||||
|
tool_calls = list(tool_call_pattern.finditer(string))
|
||||||
|
|
||||||
|
if not tool_calls:
|
||||||
|
# No tool calls — use original single-pass extraction
|
||||||
thinking_content, remaining_content = extract_thinking_block(string)
|
thinking_content, remaining_content = extract_thinking_block(string)
|
||||||
|
|
||||||
# Build individual HTML blocks
|
|
||||||
blocks = []
|
blocks = []
|
||||||
|
|
||||||
# Add thinking block if present
|
|
||||||
thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content))
|
thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content))
|
||||||
if thinking_html:
|
if thinking_html:
|
||||||
blocks.append(thinking_html)
|
blocks.append(thinking_html)
|
||||||
|
|
||||||
# Add main content block
|
|
||||||
main_html = build_main_content_block(remaining_content)
|
main_html = build_main_content_block(remaining_content)
|
||||||
if main_html:
|
if main_html:
|
||||||
blocks.append(main_html)
|
blocks.append(main_html)
|
||||||
|
|
||||||
# Assemble all blocks into final HTML
|
|
||||||
return ''.join(blocks)
|
return ''.join(blocks)
|
||||||
|
|
||||||
|
# Split string into text segments around tool_call blocks and
|
||||||
|
# run extract_thinking_block on each segment for full format support.
|
||||||
|
html_parts = []
|
||||||
|
last_end = 0
|
||||||
|
tool_idx = 0
|
||||||
|
think_idx = 0
|
||||||
|
|
||||||
|
def process_text_segment(text, is_last_segment):
|
||||||
|
"""Process a text segment between tool_call blocks for thinking content."""
|
||||||
|
nonlocal think_idx
|
||||||
|
if not text.strip():
|
||||||
|
return
|
||||||
|
|
||||||
|
while text.strip():
|
||||||
|
thinking_content, remaining = extract_thinking_block(text)
|
||||||
|
if thinking_content is None:
|
||||||
|
break
|
||||||
|
has_remaining = bool(remaining.strip()) or not is_last_segment
|
||||||
|
html_parts.append(build_thinking_block(thinking_content, message_id, has_remaining, think_idx))
|
||||||
|
think_idx += 1
|
||||||
|
text = remaining
|
||||||
|
|
||||||
|
if text.strip():
|
||||||
|
html_parts.append(process_markdown_content(text))
|
||||||
|
|
||||||
|
for tc in tool_calls:
|
||||||
|
# Process text before this tool_call
|
||||||
|
process_text_segment(string[last_end:tc.start()], is_last_segment=False)
|
||||||
|
|
||||||
|
# Add tool call accordion
|
||||||
|
header = tc.group(1).strip()
|
||||||
|
body = tc.group(2).strip()
|
||||||
|
html_parts.append(build_tool_call_block(header, body, message_id, tool_idx))
|
||||||
|
tool_idx += 1
|
||||||
|
last_end = tc.end()
|
||||||
|
|
||||||
|
# Process text after the last tool_call
|
||||||
|
process_text_segment(string[last_end:], is_last_segment=True)
|
||||||
|
|
||||||
|
return ''.join(html_parts)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):
|
def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):
|
||||||
'''
|
'''
|
||||||
|
|
@ -455,6 +453,7 @@ branch_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24
|
||||||
edit_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="tabler-icon tabler-icon-pencil"><path d="M4 20h4l10.5 -10.5a2.828 2.828 0 1 0 -4 -4l-10.5 10.5v4"></path><path d="M13.5 6.5l4 4"></path></svg>'''
|
edit_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="tabler-icon tabler-icon-pencil"><path d="M4 20h4l10.5 -10.5a2.828 2.828 0 1 0 -4 -4l-10.5 10.5v4"></path><path d="M13.5 6.5l4 4"></path></svg>'''
|
||||||
info_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
|
info_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
|
||||||
info_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
|
info_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-info-circle"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M12 2a10 10 0 0 1 0 20a10 10 0 0 1 0 -20z" /><path d="M12 16v-4" /><path d="M12 8h.01" /></svg>'''
|
||||||
|
tool_svg_small = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="thinking-icon tabler-icon tabler-icon-tool"><path stroke="none" d="M0 0h24v24H0z" fill="none"/><path d="M7 10h3v-3l-3.5 -3.5a6 6 0 0 1 8 8l6 6a2 2 0 0 1 -3 3l-6 -6a6 6 0 0 1 -8 -8l3.5 3.5" /></svg>'''
|
||||||
attachment_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.48-8.48l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48"></path></svg>'''
|
attachment_svg = '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21.44 11.05l-9.19 9.19a6 6 0 0 1-8.48-8.48l9.19-9.19a4 4 0 0 1 5.66 5.66l-9.2 9.19a2 2 0 0 1-2.83-2.83l8.49-8.48"></path></svg>'''
|
||||||
|
|
||||||
copy_button = f'<button class="footer-button footer-copy-button" title="Copy" onclick="copyToClipboard(this)">{copy_svg}</button>'
|
copy_button = f'<button class="footer-button footer-copy-button" title="Copy" onclick="copyToClipboard(this)">{copy_svg}</button>'
|
||||||
|
|
@ -648,10 +647,10 @@ def generate_instruct_html(history, last_message_only=False):
|
||||||
|
|
||||||
def get_character_image_with_cache_buster():
|
def get_character_image_with_cache_buster():
|
||||||
"""Get character image URL with cache busting based on file modification time"""
|
"""Get character image URL with cache busting based on file modification time"""
|
||||||
cache_path = Path("user_data/cache/pfp_character_thumb.png")
|
cache_path = shared.user_data_dir / "cache" / "pfp_character_thumb.png"
|
||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
mtime = int(cache_path.stat().st_mtime)
|
mtime = int(cache_path.stat().st_mtime)
|
||||||
return f'<img src="file/user_data/cache/pfp_character_thumb.png?{mtime}" class="pfp_character">'
|
return f'<img src="file/{shared.user_data_dir}/cache/pfp_character_thumb.png?{mtime}" class="pfp_character">'
|
||||||
|
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
@ -675,8 +674,8 @@ def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=
|
||||||
|
|
||||||
# Get appropriate image
|
# Get appropriate image
|
||||||
if role == "user":
|
if role == "user":
|
||||||
img = (f'<img src="file/user_data/cache/pfp_me.png?{time.time() if reset_cache else ""}">'
|
img = (f'<img src="file/{shared.user_data_dir}/cache/pfp_me.png?{time.time() if reset_cache else ""}">'
|
||||||
if Path("user_data/cache/pfp_me.png").exists() else '')
|
if (shared.user_data_dir / "cache" / "pfp_me.png").exists() else '')
|
||||||
else:
|
else:
|
||||||
img = img_bot
|
img = img_bot
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,18 @@ def process_message_content(content: Any) -> Tuple[str, List[Image.Image]]:
|
||||||
# Support external URLs
|
# Support external URLs
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
response = requests.get(image_url, timeout=10)
|
from urllib.parse import urljoin
|
||||||
|
from modules.web_search import _validate_url
|
||||||
|
_validate_url(image_url)
|
||||||
|
url = image_url
|
||||||
|
for _ in range(5):
|
||||||
|
response = requests.get(url, timeout=10, allow_redirects=False)
|
||||||
|
if response.is_redirect and 'Location' in response.headers:
|
||||||
|
url = urljoin(url, response.headers['Location'])
|
||||||
|
_validate_url(url)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
image_data = response.content
|
image_data = response.content
|
||||||
image = Image.open(io.BytesIO(image_data))
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ class LlamaServer:
|
||||||
self.process = None
|
self.process = None
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
self.vocabulary_size = None
|
self.vocabulary_size = None
|
||||||
|
self.n_ctx = None
|
||||||
self.bos_token = "<s>"
|
self.bos_token = "<s>"
|
||||||
self.last_prompt_token_count = 0
|
self.last_prompt_token_count = 0
|
||||||
|
|
||||||
|
|
@ -75,6 +76,8 @@ class LlamaServer:
|
||||||
"top_p": state["top_p"],
|
"top_p": state["top_p"],
|
||||||
"min_p": state["min_p"],
|
"min_p": state["min_p"],
|
||||||
"top_n_sigma": state["top_n_sigma"] if state["top_n_sigma"] > 0 else -1,
|
"top_n_sigma": state["top_n_sigma"] if state["top_n_sigma"] > 0 else -1,
|
||||||
|
"adaptive_target": state["adaptive_target"] if state["adaptive_target"] > 0 else -1,
|
||||||
|
"adaptive_decay": state["adaptive_decay"],
|
||||||
"typical_p": state["typical_p"],
|
"typical_p": state["typical_p"],
|
||||||
"repeat_penalty": state["repetition_penalty"],
|
"repeat_penalty": state["repetition_penalty"],
|
||||||
"repeat_last_n": state["repetition_penalty_range"],
|
"repeat_last_n": state["repetition_penalty_range"],
|
||||||
|
|
@ -119,15 +122,32 @@ class LlamaServer:
|
||||||
penalty_found = True
|
penalty_found = True
|
||||||
|
|
||||||
# Move temperature to the end if temperature_last is true and temperature exists in the list
|
# Move temperature to the end if temperature_last is true and temperature exists in the list
|
||||||
if state["temperature_last"] and "temperature" in samplers:
|
if state["temperature_last"] and "temperature" in filtered_samplers:
|
||||||
samplers.remove("temperature")
|
filtered_samplers.remove("temperature")
|
||||||
samplers.append("temperature")
|
filtered_samplers.append("temperature")
|
||||||
|
|
||||||
|
# adaptive-p replaces the default dist sampler; llama.cpp always
|
||||||
|
# places it at the end of the chain regardless of position, so we
|
||||||
|
# activate it based on the parameter value rather than sampler order.
|
||||||
|
if state.get("adaptive_target", 0) > 0:
|
||||||
|
filtered_samplers.append("adaptive_p")
|
||||||
|
|
||||||
payload["samplers"] = filtered_samplers
|
payload["samplers"] = filtered_samplers
|
||||||
|
|
||||||
|
logit_bias = []
|
||||||
if state['custom_token_bans']:
|
if state['custom_token_bans']:
|
||||||
to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
|
logit_bias.extend([[int(token_id.strip()), False] for token_id in state['custom_token_bans'].split(',') if token_id.strip()])
|
||||||
payload["logit_bias"] = to_ban
|
|
||||||
|
if state.get('logit_bias'):
|
||||||
|
for token_id_str, bias in state['logit_bias'].items():
|
||||||
|
logit_bias.append([int(token_id_str), bias])
|
||||||
|
|
||||||
|
if logit_bias:
|
||||||
|
payload["logit_bias"] = logit_bias
|
||||||
|
|
||||||
|
n_probs = state.get('logprobs', 0)
|
||||||
|
if n_probs and n_probs > 0:
|
||||||
|
payload["n_probs"] = n_probs
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
@ -200,16 +220,19 @@ class LlamaServer:
|
||||||
# Make the generation request
|
# Make the generation request
|
||||||
response = self.session.post(url, json=payload, stream=True)
|
response = self.session.post(url, json=payload, stream=True)
|
||||||
try:
|
try:
|
||||||
if response.status_code == 400 and response.json()["error"]["type"] == "exceed_context_size_error":
|
if response.status_code == 400 and response.json().get("error", {}).get("type") == "exceed_context_size_error":
|
||||||
logger.error("The request exceeds the available context size, try increasing it")
|
logger.error("The request exceeds the available context size, try increasing it")
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
response.raise_for_status() # Raise an exception for HTTP errors
|
response.raise_for_status() # Raise an exception for HTTP errors
|
||||||
|
|
||||||
full_text = ""
|
full_text = ""
|
||||||
|
self.last_completion_probabilities = []
|
||||||
|
|
||||||
# Process the streaming response
|
# Process the streaming response
|
||||||
|
stop_event = state.get('stop_event')
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
if shared.stop_everything:
|
if shared.stop_everything or (stop_event and stop_event.is_set()):
|
||||||
break
|
break
|
||||||
|
|
||||||
if not line:
|
if not line:
|
||||||
|
|
@ -230,6 +253,10 @@ class LlamaServer:
|
||||||
full_text += data['content']
|
full_text += data['content']
|
||||||
yield full_text
|
yield full_text
|
||||||
|
|
||||||
|
# Capture logprobs if present
|
||||||
|
if 'completion_probabilities' in data:
|
||||||
|
self.last_completion_probabilities.extend(data['completion_probabilities'])
|
||||||
|
|
||||||
# Check if generation is complete
|
# Check if generation is complete
|
||||||
if data.get('stop', False):
|
if data.get('stop', False):
|
||||||
break
|
break
|
||||||
|
|
@ -278,6 +305,8 @@ class LlamaServer:
|
||||||
return result["completion_probabilities"][0]["top_probs"]
|
return result["completion_probabilities"][0]["top_probs"]
|
||||||
else:
|
else:
|
||||||
return result["completion_probabilities"][0]["top_logprobs"]
|
return result["completion_probabilities"][0]["top_logprobs"]
|
||||||
|
|
||||||
|
time.sleep(0.05)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
|
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
|
||||||
|
|
||||||
|
|
@ -292,16 +321,35 @@ class LlamaServer:
|
||||||
self.vocabulary_size = model_info["meta"]["n_vocab"]
|
self.vocabulary_size = model_info["meta"]["n_vocab"]
|
||||||
|
|
||||||
def _get_bos_token(self):
|
def _get_bos_token(self):
|
||||||
"""Get and store the model's BOS token."""
|
"""Get and store the model's BOS token and context size."""
|
||||||
url = f"http://127.0.0.1:{self.port}/props"
|
url = f"http://127.0.0.1:{self.port}/props"
|
||||||
response = self.session.get(url).json()
|
response = self.session.get(url).json()
|
||||||
if "bos_token" in response:
|
if "bos_token" in response:
|
||||||
self.bos_token = response["bos_token"]
|
self.bos_token = response["bos_token"]
|
||||||
|
|
||||||
def _find_available_port(self):
|
# Get actual n_ctx from the server (important when --fit auto-selects it)
|
||||||
"""Find an available port by letting the OS assign one."""
|
n_ctx = response.get("default_generation_settings", {}).get("n_ctx")
|
||||||
|
if n_ctx:
|
||||||
|
self.n_ctx = n_ctx
|
||||||
|
|
||||||
|
def _is_port_available(self, port):
|
||||||
|
"""Check if a port is available for use."""
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
s.bind(('', 0)) # Bind to port 0 to get an available port
|
try:
|
||||||
|
s.bind(('', port))
|
||||||
|
return True
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _find_available_port(self):
|
||||||
|
"""Find an available port, preferring main port + 5."""
|
||||||
|
preferred_port = shared.args.api_port + 5
|
||||||
|
if self._is_port_available(preferred_port):
|
||||||
|
return preferred_port
|
||||||
|
|
||||||
|
# Fall back to OS-assigned random port
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(('', 0))
|
||||||
return s.getsockname()[1]
|
return s.getsockname()[1]
|
||||||
|
|
||||||
def _start_server(self):
|
def _start_server(self):
|
||||||
|
|
@ -314,8 +362,6 @@ class LlamaServer:
|
||||||
cmd = [
|
cmd = [
|
||||||
self.server_path,
|
self.server_path,
|
||||||
"--model", self.model_path,
|
"--model", self.model_path,
|
||||||
"--ctx-size", str(shared.args.ctx_size),
|
|
||||||
"--gpu-layers", str(shared.args.gpu_layers),
|
|
||||||
"--batch-size", str(shared.args.batch_size),
|
"--batch-size", str(shared.args.batch_size),
|
||||||
"--ubatch-size", str(shared.args.ubatch_size),
|
"--ubatch-size", str(shared.args.ubatch_size),
|
||||||
"--port", str(self.port),
|
"--port", str(self.port),
|
||||||
|
|
@ -323,6 +369,19 @@ class LlamaServer:
|
||||||
"--flash-attn", "on",
|
"--flash-attn", "on",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if shared.args.ctx_size > 0:
|
||||||
|
cmd += ["--ctx-size", str(shared.args.ctx_size)]
|
||||||
|
elif shared.args.gpu_layers >= 0:
|
||||||
|
cmd += ["--ctx-size", "8192"]
|
||||||
|
|
||||||
|
if shared.args.gpu_layers >= 0:
|
||||||
|
cmd += ["--gpu-layers", str(shared.args.gpu_layers), "--fit", "off"]
|
||||||
|
else:
|
||||||
|
cmd += ["--fit", "on"]
|
||||||
|
cmd += ["--fit-ctx", "8192"]
|
||||||
|
if shared.args.fit_target:
|
||||||
|
cmd += ["--fit-target", shared.args.fit_target]
|
||||||
|
|
||||||
if shared.args.threads > 0:
|
if shared.args.threads > 0:
|
||||||
cmd += ["--threads", str(shared.args.threads)]
|
cmd += ["--threads", str(shared.args.threads)]
|
||||||
if shared.args.threads_batch > 0:
|
if shared.args.threads_batch > 0:
|
||||||
|
|
@ -345,14 +404,10 @@ class LlamaServer:
|
||||||
if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
|
if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
|
||||||
cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
|
cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
|
||||||
cache_type = shared.args.cache_type
|
cache_type = shared.args.cache_type
|
||||||
if shared.args.compress_pos_emb != 1:
|
|
||||||
cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
|
|
||||||
if shared.args.rope_freq_base > 0:
|
|
||||||
cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
|
|
||||||
if shared.args.mmproj not in [None, 'None']:
|
if shared.args.mmproj not in [None, 'None']:
|
||||||
path = Path(shared.args.mmproj)
|
path = Path(shared.args.mmproj)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
path = Path('user_data/mmproj') / shared.args.mmproj
|
path = shared.user_data_dir / 'mmproj' / shared.args.mmproj
|
||||||
|
|
||||||
if path.exists():
|
if path.exists():
|
||||||
cmd += ["--mmproj", str(path)]
|
cmd += ["--mmproj", str(path)]
|
||||||
|
|
@ -364,7 +419,7 @@ class LlamaServer:
|
||||||
else:
|
else:
|
||||||
model_file = sorted(path.glob('*.gguf'))[0]
|
model_file = sorted(path.glob('*.gguf'))[0]
|
||||||
|
|
||||||
cmd += ["--model-draft", model_file]
|
cmd += ["--model-draft", str(model_file)]
|
||||||
if shared.args.draft_max > 0:
|
if shared.args.draft_max > 0:
|
||||||
cmd += ["--draft-max", str(shared.args.draft_max)]
|
cmd += ["--draft-max", str(shared.args.draft_max)]
|
||||||
if shared.args.gpu_layers_draft > 0:
|
if shared.args.gpu_layers_draft > 0:
|
||||||
|
|
@ -373,6 +428,13 @@ class LlamaServer:
|
||||||
cmd += ["--device-draft", shared.args.device_draft]
|
cmd += ["--device-draft", shared.args.device_draft]
|
||||||
if shared.args.ctx_size_draft > 0:
|
if shared.args.ctx_size_draft > 0:
|
||||||
cmd += ["--ctx-size-draft", str(shared.args.ctx_size_draft)]
|
cmd += ["--ctx-size-draft", str(shared.args.ctx_size_draft)]
|
||||||
|
if shared.args.spec_type != 'none':
|
||||||
|
cmd += ["--spec-type", shared.args.spec_type]
|
||||||
|
cmd += ["--draft-max", str(shared.args.draft_max)]
|
||||||
|
cmd += ["--spec-ngram-size-n", str(shared.args.spec_ngram_size_n)]
|
||||||
|
cmd += ["--spec-ngram-size-m", str(shared.args.spec_ngram_size_m)]
|
||||||
|
cmd += ["--spec-ngram-min-hits", str(shared.args.spec_ngram_min_hits)]
|
||||||
|
cmd += ["--parallel", str(shared.args.parallel)]
|
||||||
if shared.args.streaming_llm:
|
if shared.args.streaming_llm:
|
||||||
cmd += ["--cache-reuse", "1"]
|
cmd += ["--cache-reuse", "1"]
|
||||||
cmd += ["--swa-full"]
|
cmd += ["--swa-full"]
|
||||||
|
|
@ -385,8 +447,11 @@ class LlamaServer:
|
||||||
extra_flags = extra_flags[1:-1].strip()
|
extra_flags = extra_flags[1:-1].strip()
|
||||||
|
|
||||||
for flag_item in extra_flags.split(','):
|
for flag_item in extra_flags.split(','):
|
||||||
|
flag_item = flag_item.strip()
|
||||||
if '=' in flag_item:
|
if '=' in flag_item:
|
||||||
flag, value = flag_item.split('=', 1)
|
flag, value = flag_item.split('=', 1)
|
||||||
|
flag = flag.strip()
|
||||||
|
value = value.strip()
|
||||||
if len(flag) <= 3:
|
if len(flag) <= 3:
|
||||||
cmd += [f"-{flag}", value]
|
cmd += [f"-{flag}", value]
|
||||||
else:
|
else:
|
||||||
|
|
@ -410,7 +475,9 @@ class LlamaServer:
|
||||||
print(' '.join(str(item) for item in cmd[1:]))
|
print(' '.join(str(item) for item in cmd[1:]))
|
||||||
print()
|
print()
|
||||||
|
|
||||||
logger.info(f"Using gpu_layers={shared.args.gpu_layers} | ctx_size={shared.args.ctx_size} | cache_type={cache_type}")
|
gpu_layers_str = "auto" if shared.args.gpu_layers < 0 else str(shared.args.gpu_layers)
|
||||||
|
ctx_size_str = "auto" if shared.args.ctx_size == 0 and shared.args.gpu_layers < 0 else str(shared.args.ctx_size or 8192)
|
||||||
|
logger.info(f"Using gpu_layers={gpu_layers_str} | ctx_size={ctx_size_str} | cache_type={cache_type}")
|
||||||
# Start the server with pipes for output
|
# Start the server with pipes for output
|
||||||
self.process = subprocess.Popen(
|
self.process = subprocess.Popen(
|
||||||
cmd,
|
cmd,
|
||||||
|
|
@ -434,7 +501,7 @@ class LlamaServer:
|
||||||
response = self.session.get(health_url)
|
response = self.session.get(health_url)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
break
|
break
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
@ -464,6 +531,7 @@ class LlamaServer:
|
||||||
self.process.wait(timeout=5)
|
self.process.wait(timeout=5)
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
self.process.kill()
|
self.process.kill()
|
||||||
|
self.process.wait(timeout=5)
|
||||||
|
|
||||||
self.process = None
|
self.process = None
|
||||||
|
|
||||||
|
|
@ -474,6 +542,8 @@ def filter_stderr_with_progress(process_stderr):
|
||||||
inline (overwriting the same line) until completion.
|
inline (overwriting the same line) until completion.
|
||||||
"""
|
"""
|
||||||
progress_re = re.compile(r'slot update_slots: id.*progress = (\d+\.\d+)')
|
progress_re = re.compile(r'slot update_slots: id.*progress = (\d+\.\d+)')
|
||||||
|
ansi_re = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]')
|
||||||
|
log_prefix_re = re.compile(r'^[IWED] ')
|
||||||
last_was_progress = False
|
last_was_progress = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -492,6 +562,7 @@ def filter_stderr_with_progress(process_stderr):
|
||||||
line_bytes, buffer = buffer.split(b'\n', 1)
|
line_bytes, buffer = buffer.split(b'\n', 1)
|
||||||
try:
|
try:
|
||||||
line = line_bytes.decode('utf-8', errors='replace').strip('\r\n')
|
line = line_bytes.decode('utf-8', errors='replace').strip('\r\n')
|
||||||
|
line = log_prefix_re.sub('', ansi_re.sub('', line))
|
||||||
if line: # Process non-empty lines
|
if line: # Process non-empty lines
|
||||||
match = progress_re.search(line)
|
match = progress_re.search(line)
|
||||||
|
|
||||||
|
|
@ -511,7 +582,7 @@ def filter_stderr_with_progress(process_stderr):
|
||||||
last_was_progress = (progress < 1.0)
|
last_was_progress = (progress < 1.0)
|
||||||
|
|
||||||
# skip noise lines
|
# skip noise lines
|
||||||
elif not (line.startswith(('srv ', 'slot ')) or 'log_server_r: request: GET /health' in line):
|
elif not (line.startswith(('srv ', 'slot ')) or 'log_server_r: request: GET /health' in line or 'No parser definition detected' in line):
|
||||||
# if we were in progress, finish that line first
|
# if we were in progress, finish that line first
|
||||||
if last_was_progress:
|
if last_was_progress:
|
||||||
print(file=sys.stderr)
|
print(file=sys.stderr)
|
||||||
|
|
@ -527,5 +598,5 @@ def filter_stderr_with_progress(process_stderr):
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
process_stderr.close()
|
process_stderr.close()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
import functools
|
import functools
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
loaders_and_params = OrderedDict({
|
loaders_and_params = OrderedDict({
|
||||||
'llama.cpp': [
|
'llama.cpp': [
|
||||||
'gpu_layers',
|
'gpu_layers',
|
||||||
|
'fit_target',
|
||||||
'cpu_moe',
|
'cpu_moe',
|
||||||
'threads',
|
'threads',
|
||||||
'threads_batch',
|
'threads_batch',
|
||||||
|
|
@ -16,18 +15,22 @@ loaders_and_params = OrderedDict({
|
||||||
'tensor_split',
|
'tensor_split',
|
||||||
'extra_flags',
|
'extra_flags',
|
||||||
'streaming_llm',
|
'streaming_llm',
|
||||||
'rope_freq_base',
|
|
||||||
'compress_pos_emb',
|
|
||||||
'row_split',
|
'row_split',
|
||||||
'no_kv_offload',
|
'no_kv_offload',
|
||||||
'no_mmap',
|
'no_mmap',
|
||||||
'mlock',
|
'mlock',
|
||||||
'numa',
|
'numa',
|
||||||
|
'parallel',
|
||||||
'model_draft',
|
'model_draft',
|
||||||
'draft_max',
|
'draft_max',
|
||||||
'gpu_layers_draft',
|
'gpu_layers_draft',
|
||||||
'device_draft',
|
'device_draft',
|
||||||
'ctx_size_draft',
|
'ctx_size_draft',
|
||||||
|
'ngram_header',
|
||||||
|
'spec_type',
|
||||||
|
'spec_ngram_size_n',
|
||||||
|
'spec_ngram_size_m',
|
||||||
|
'spec_ngram_min_hits',
|
||||||
'speculative_decoding_accordion',
|
'speculative_decoding_accordion',
|
||||||
'mmproj',
|
'mmproj',
|
||||||
'mmproj_accordion',
|
'mmproj_accordion',
|
||||||
|
|
@ -36,8 +39,6 @@ loaders_and_params = OrderedDict({
|
||||||
'Transformers': [
|
'Transformers': [
|
||||||
'gpu_split',
|
'gpu_split',
|
||||||
'cpu_memory',
|
'cpu_memory',
|
||||||
'alpha_value',
|
|
||||||
'compress_pos_emb',
|
|
||||||
'compute_dtype',
|
'compute_dtype',
|
||||||
'quant_type',
|
'quant_type',
|
||||||
'load_in_8bit',
|
'load_in_8bit',
|
||||||
|
|
@ -64,46 +65,12 @@ loaders_and_params = OrderedDict({
|
||||||
'gpu_split',
|
'gpu_split',
|
||||||
'model_draft',
|
'model_draft',
|
||||||
'draft_max',
|
'draft_max',
|
||||||
'ctx_size_draft',
|
|
||||||
'speculative_decoding_accordion',
|
'speculative_decoding_accordion',
|
||||||
'enable_tp',
|
'enable_tp',
|
||||||
'tp_backend',
|
'tp_backend',
|
||||||
],
|
],
|
||||||
'ExLlamav2_HF': [
|
|
||||||
'ctx_size',
|
|
||||||
'cache_type',
|
|
||||||
'gpu_split',
|
|
||||||
'alpha_value',
|
|
||||||
'compress_pos_emb',
|
|
||||||
'num_experts_per_token',
|
|
||||||
'autosplit',
|
|
||||||
'enable_tp',
|
|
||||||
'no_flash_attn',
|
|
||||||
'no_xformers',
|
|
||||||
'no_sdpa',
|
|
||||||
'cfg_cache',
|
|
||||||
'no_use_fast',
|
|
||||||
],
|
|
||||||
'ExLlamav2': [
|
|
||||||
'ctx_size',
|
|
||||||
'cache_type',
|
|
||||||
'gpu_split',
|
|
||||||
'alpha_value',
|
|
||||||
'compress_pos_emb',
|
|
||||||
'num_experts_per_token',
|
|
||||||
'autosplit',
|
|
||||||
'enable_tp',
|
|
||||||
'no_flash_attn',
|
|
||||||
'no_xformers',
|
|
||||||
'no_sdpa',
|
|
||||||
'model_draft',
|
|
||||||
'draft_max',
|
|
||||||
'ctx_size_draft',
|
|
||||||
'speculative_decoding_accordion',
|
|
||||||
],
|
|
||||||
'TensorRT-LLM': [
|
'TensorRT-LLM': [
|
||||||
'ctx_size',
|
'ctx_size',
|
||||||
'cpp_runner',
|
|
||||||
'tensorrt_llm_info',
|
'tensorrt_llm_info',
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
|
|
@ -128,6 +95,8 @@ def transformers_samplers():
|
||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'top_n_sigma',
|
'top_n_sigma',
|
||||||
|
'adaptive_target',
|
||||||
|
'adaptive_decay',
|
||||||
'dry_multiplier',
|
'dry_multiplier',
|
||||||
'dry_allowed_length',
|
'dry_allowed_length',
|
||||||
'dry_base',
|
'dry_base',
|
||||||
|
|
@ -183,54 +152,8 @@ loaders_samplers = {
|
||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'top_n_sigma',
|
'top_n_sigma',
|
||||||
'dry_multiplier',
|
'adaptive_target',
|
||||||
'dry_allowed_length',
|
'adaptive_decay',
|
||||||
'dry_base',
|
|
||||||
'repetition_penalty',
|
|
||||||
'frequency_penalty',
|
|
||||||
'presence_penalty',
|
|
||||||
'encoder_repetition_penalty',
|
|
||||||
'no_repeat_ngram_size',
|
|
||||||
'repetition_penalty_range',
|
|
||||||
'guidance_scale',
|
|
||||||
'mirostat_mode',
|
|
||||||
'mirostat_tau',
|
|
||||||
'mirostat_eta',
|
|
||||||
'do_sample',
|
|
||||||
'dynamic_temperature',
|
|
||||||
'temperature_last',
|
|
||||||
'auto_max_new_tokens',
|
|
||||||
'ban_eos_token',
|
|
||||||
'add_bos_token',
|
|
||||||
'enable_thinking',
|
|
||||||
'reasoning_effort',
|
|
||||||
'skip_special_tokens',
|
|
||||||
'seed',
|
|
||||||
'sampler_priority',
|
|
||||||
'custom_token_bans',
|
|
||||||
'negative_prompt',
|
|
||||||
'dry_sequence_breakers',
|
|
||||||
'grammar_string',
|
|
||||||
'grammar_file_row',
|
|
||||||
},
|
|
||||||
'ExLlamav2_HF': {
|
|
||||||
'temperature',
|
|
||||||
'dynatemp_low',
|
|
||||||
'dynatemp_high',
|
|
||||||
'dynatemp_exponent',
|
|
||||||
'smoothing_factor',
|
|
||||||
'smoothing_curve',
|
|
||||||
'min_p',
|
|
||||||
'top_p',
|
|
||||||
'top_k',
|
|
||||||
'typical_p',
|
|
||||||
'xtc_threshold',
|
|
||||||
'xtc_probability',
|
|
||||||
'epsilon_cutoff',
|
|
||||||
'eta_cutoff',
|
|
||||||
'tfs',
|
|
||||||
'top_a',
|
|
||||||
'top_n_sigma',
|
|
||||||
'dry_multiplier',
|
'dry_multiplier',
|
||||||
'dry_allowed_length',
|
'dry_allowed_length',
|
||||||
'dry_base',
|
'dry_base',
|
||||||
|
|
@ -266,6 +189,8 @@ loaders_samplers = {
|
||||||
'min_p',
|
'min_p',
|
||||||
'top_p',
|
'top_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
|
'adaptive_target',
|
||||||
|
'adaptive_decay',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
'frequency_penalty',
|
'frequency_penalty',
|
||||||
'presence_penalty',
|
'presence_penalty',
|
||||||
|
|
@ -276,44 +201,9 @@ loaders_samplers = {
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
'add_bos_token',
|
'add_bos_token',
|
||||||
'enable_thinking',
|
'enable_thinking',
|
||||||
'seed',
|
|
||||||
'skip_special_tokens',
|
|
||||||
},
|
|
||||||
'ExLlamav2': {
|
|
||||||
'temperature',
|
|
||||||
'dynatemp_low',
|
|
||||||
'dynatemp_high',
|
|
||||||
'dynatemp_exponent',
|
|
||||||
'smoothing_factor',
|
|
||||||
'min_p',
|
|
||||||
'top_p',
|
|
||||||
'top_k',
|
|
||||||
'typical_p',
|
|
||||||
'xtc_threshold',
|
|
||||||
'xtc_probability',
|
|
||||||
'tfs',
|
|
||||||
'top_a',
|
|
||||||
'dry_multiplier',
|
|
||||||
'dry_allowed_length',
|
|
||||||
'dry_base',
|
|
||||||
'repetition_penalty',
|
|
||||||
'frequency_penalty',
|
|
||||||
'presence_penalty',
|
|
||||||
'repetition_penalty_range',
|
|
||||||
'mirostat_mode',
|
|
||||||
'mirostat_tau',
|
|
||||||
'mirostat_eta',
|
|
||||||
'dynamic_temperature',
|
|
||||||
'temperature_last',
|
|
||||||
'auto_max_new_tokens',
|
|
||||||
'ban_eos_token',
|
|
||||||
'add_bos_token',
|
|
||||||
'enable_thinking',
|
|
||||||
'reasoning_effort',
|
'reasoning_effort',
|
||||||
'skip_special_tokens',
|
|
||||||
'seed',
|
'seed',
|
||||||
'custom_token_bans',
|
'skip_special_tokens',
|
||||||
'dry_sequence_breakers',
|
|
||||||
},
|
},
|
||||||
'llama.cpp': {
|
'llama.cpp': {
|
||||||
'temperature',
|
'temperature',
|
||||||
|
|
@ -327,6 +217,8 @@ loaders_samplers = {
|
||||||
'xtc_threshold',
|
'xtc_threshold',
|
||||||
'xtc_probability',
|
'xtc_probability',
|
||||||
'top_n_sigma',
|
'top_n_sigma',
|
||||||
|
'adaptive_target',
|
||||||
|
'adaptive_decay',
|
||||||
'dry_multiplier',
|
'dry_multiplier',
|
||||||
'dry_allowed_length',
|
'dry_allowed_length',
|
||||||
'dry_base',
|
'dry_base',
|
||||||
|
|
@ -346,6 +238,7 @@ loaders_samplers = {
|
||||||
'reasoning_effort',
|
'reasoning_effort',
|
||||||
'seed',
|
'seed',
|
||||||
'sampler_priority',
|
'sampler_priority',
|
||||||
|
'custom_token_bans',
|
||||||
'dry_sequence_breakers',
|
'dry_sequence_breakers',
|
||||||
'grammar_string',
|
'grammar_string',
|
||||||
'grammar_file_row',
|
'grammar_file_row',
|
||||||
|
|
@ -354,11 +247,16 @@ loaders_samplers = {
|
||||||
'temperature',
|
'temperature',
|
||||||
'top_p',
|
'top_p',
|
||||||
'top_k',
|
'top_k',
|
||||||
|
'min_p',
|
||||||
'repetition_penalty',
|
'repetition_penalty',
|
||||||
'frequency_penalty',
|
'frequency_penalty',
|
||||||
'presence_penalty',
|
'presence_penalty',
|
||||||
|
'no_repeat_ngram_size',
|
||||||
'auto_max_new_tokens',
|
'auto_max_new_tokens',
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
|
'add_bos_token',
|
||||||
|
'skip_special_tokens',
|
||||||
|
'seed',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -374,6 +272,7 @@ def list_all_samplers():
|
||||||
|
|
||||||
|
|
||||||
def blacklist_samplers(loader, dynamic_temperature):
|
def blacklist_samplers(loader, dynamic_temperature):
|
||||||
|
import gradio as gr
|
||||||
all_samplers = list_all_samplers()
|
all_samplers = list_all_samplers()
|
||||||
output = []
|
output = []
|
||||||
|
|
||||||
|
|
@ -399,7 +298,58 @@ def get_all_params():
|
||||||
return sorted(all_params)
|
return sorted(all_params)
|
||||||
|
|
||||||
|
|
||||||
|
def list_model_elements():
|
||||||
|
return [
|
||||||
|
'filter_by_loader',
|
||||||
|
'loader',
|
||||||
|
'cpu_memory',
|
||||||
|
'gpu_layers',
|
||||||
|
'fit_target',
|
||||||
|
'cpu_moe',
|
||||||
|
'threads',
|
||||||
|
'threads_batch',
|
||||||
|
'batch_size',
|
||||||
|
'ubatch_size',
|
||||||
|
'ctx_size',
|
||||||
|
'cache_type',
|
||||||
|
'tensor_split',
|
||||||
|
'extra_flags',
|
||||||
|
'streaming_llm',
|
||||||
|
'gpu_split',
|
||||||
|
'compute_dtype',
|
||||||
|
'quant_type',
|
||||||
|
'load_in_8bit',
|
||||||
|
'load_in_4bit',
|
||||||
|
'attn_implementation',
|
||||||
|
'cpu',
|
||||||
|
'disk',
|
||||||
|
'row_split',
|
||||||
|
'no_kv_offload',
|
||||||
|
'no_mmap',
|
||||||
|
'mlock',
|
||||||
|
'numa',
|
||||||
|
'parallel',
|
||||||
|
'use_double_quant',
|
||||||
|
'bf16',
|
||||||
|
'enable_tp',
|
||||||
|
'tp_backend',
|
||||||
|
'cfg_cache',
|
||||||
|
'no_use_fast',
|
||||||
|
'model_draft',
|
||||||
|
'draft_max',
|
||||||
|
'gpu_layers_draft',
|
||||||
|
'device_draft',
|
||||||
|
'ctx_size_draft',
|
||||||
|
'spec_type',
|
||||||
|
'spec_ngram_size_n',
|
||||||
|
'spec_ngram_size_m',
|
||||||
|
'spec_ngram_min_hits',
|
||||||
|
'mmproj',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def make_loader_params_visible(loader):
|
def make_loader_params_visible(loader):
|
||||||
|
import gradio as gr
|
||||||
params = []
|
params = []
|
||||||
all_params = get_all_params()
|
all_params = get_all_params()
|
||||||
if loader in loaders_and_params:
|
if loader in loaders_and_params:
|
||||||
|
|
|
||||||
|
|
@ -70,26 +70,21 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
|
||||||
from modules import sampler_hijack
|
from modules import sampler_hijack
|
||||||
from modules.torch_utils import get_device
|
from modules.torch_utils import get_device
|
||||||
|
|
||||||
is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
|
|
||||||
is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model'
|
is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model'
|
||||||
|
|
||||||
if not use_samplers:
|
if not use_samplers:
|
||||||
state = {'stream': True}
|
state = {'stream': True}
|
||||||
|
|
||||||
if use_samplers:
|
if use_samplers:
|
||||||
if is_non_hf_exllamav2:
|
|
||||||
# sampling is all done in C++ for exllama, so it is really hard to hijack
|
|
||||||
logger.error("Sampler hijacking is not supported non-Huggingface loaders.")
|
|
||||||
return 'Error: Sampler hijacking is not supported non-Huggingface loaders. Please disable the "Use samplers" option.', previous
|
|
||||||
|
|
||||||
state['max_new_tokens'] = 1
|
state['max_new_tokens'] = 1
|
||||||
state['auto_max_new_tokens'] = False
|
state['auto_max_new_tokens'] = False
|
||||||
|
state.setdefault('stream', True)
|
||||||
for _ in generate_reply(prompt, state):
|
for _ in generate_reply(prompt, state):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
scores = sampler_hijack.global_scores[-1]
|
scores = sampler_hijack.global_scores[-1]
|
||||||
else:
|
else:
|
||||||
if is_non_hf_exllamav2 or is_non_hf_exllamav3:
|
if is_non_hf_exllamav3:
|
||||||
device = get_device()
|
device = get_device()
|
||||||
tokens = shared.tokenizer.encode(prompt)
|
tokens = shared.tokenizer.encode(prompt)
|
||||||
if device:
|
if device:
|
||||||
|
|
@ -105,7 +100,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
|
||||||
output = shared.model(input_ids=tokens)
|
output = shared.model(input_ids=tokens)
|
||||||
scores = output['logits'][-1][-1]
|
scores = output['logits'][-1][-1]
|
||||||
|
|
||||||
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
probs = torch.softmax(scores.detach(), dim=-1, dtype=torch.float)
|
||||||
topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
|
topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
|
||||||
if hasattr(shared.tokenizer, 'convert_ids_to_tokens'):
|
if hasattr(shared.tokenizer, 'convert_ids_to_tokens'):
|
||||||
tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices]
|
tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices]
|
||||||
|
|
@ -120,7 +115,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
|
||||||
if isinstance(key, bytes):
|
if isinstance(key, bytes):
|
||||||
try:
|
try:
|
||||||
key = key.decode()
|
key = key.decode()
|
||||||
except:
|
except Exception:
|
||||||
key = key.decode('latin')
|
key = key.decode('latin')
|
||||||
|
|
||||||
output[key] = row[0]
|
output[key] = row[0]
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ def get_single(value_type, file):
|
||||||
value = file.read(value_length)
|
value = file.read(value_length)
|
||||||
try:
|
try:
|
||||||
value = value.decode('utf-8')
|
value = value.decode('utf-8')
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
type_str = _simple_value_packing.get(value_type)
|
type_str = _simple_value_packing.get(value_type)
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,6 @@ def load_model(model_name, loader=None):
|
||||||
'Transformers': transformers_loader,
|
'Transformers': transformers_loader,
|
||||||
'ExLlamav3_HF': ExLlamav3_HF_loader,
|
'ExLlamav3_HF': ExLlamav3_HF_loader,
|
||||||
'ExLlamav3': ExLlamav3_loader,
|
'ExLlamav3': ExLlamav3_loader,
|
||||||
'ExLlamav2_HF': ExLlamav2_HF_loader,
|
|
||||||
'ExLlamav2': ExLlamav2_loader,
|
|
||||||
'TensorRT-LLM': TensorRT_LLM_loader,
|
'TensorRT-LLM': TensorRT_LLM_loader,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -40,6 +38,9 @@ def load_model(model_name, loader=None):
|
||||||
sampler_hijack.hijack_samplers()
|
sampler_hijack.hijack_samplers()
|
||||||
|
|
||||||
shared.args.loader = loader
|
shared.args.loader = loader
|
||||||
|
if loader != 'llama.cpp' and shared.args.ctx_size == 0:
|
||||||
|
shared.args.ctx_size = 8192
|
||||||
|
|
||||||
output = load_func_map[loader](model_name)
|
output = load_func_map[loader](model_name)
|
||||||
if type(output) is tuple:
|
if type(output) is tuple:
|
||||||
model, tokenizer = output
|
model, tokenizer = output
|
||||||
|
|
@ -54,7 +55,10 @@ def load_model(model_name, loader=None):
|
||||||
|
|
||||||
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
|
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
|
||||||
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
|
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
|
||||||
|
if shared.args.ctx_size > 0:
|
||||||
shared.settings['truncation_length'] = shared.args.ctx_size
|
shared.settings['truncation_length'] = shared.args.ctx_size
|
||||||
|
elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx:
|
||||||
|
shared.settings['truncation_length'] = model.n_ctx
|
||||||
|
|
||||||
shared.is_multimodal = False
|
shared.is_multimodal = False
|
||||||
if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):
|
if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):
|
||||||
|
|
@ -108,19 +112,6 @@ def ExLlamav3_loader(model_name):
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def ExLlamav2_HF_loader(model_name):
|
|
||||||
from modules.exllamav2_hf import Exllamav2HF
|
|
||||||
|
|
||||||
return Exllamav2HF.from_pretrained(model_name)
|
|
||||||
|
|
||||||
|
|
||||||
def ExLlamav2_loader(model_name):
|
|
||||||
from modules.exllamav2 import Exllamav2Model
|
|
||||||
|
|
||||||
model, tokenizer = Exllamav2Model.from_pretrained(model_name)
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def TensorRT_LLM_loader(model_name):
|
def TensorRT_LLM_loader(model_name):
|
||||||
try:
|
try:
|
||||||
from modules.tensorrt_llm import TensorRTLLMModel
|
from modules.tensorrt_llm import TensorRTLLMModel
|
||||||
|
|
@ -128,7 +119,7 @@ def TensorRT_LLM_loader(model_name):
|
||||||
raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.")
|
raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.")
|
||||||
|
|
||||||
model = TensorRTLLMModel.from_pretrained(model_name)
|
model = TensorRTLLMModel.from_pretrained(model_name)
|
||||||
return model
|
return model, model.tokenizer
|
||||||
|
|
||||||
|
|
||||||
def unload_model(keep_model_name=False):
|
def unload_model(keep_model_name=False):
|
||||||
|
|
@ -138,10 +129,10 @@ def unload_model(keep_model_name=False):
|
||||||
model_class_name = shared.model.__class__.__name__
|
model_class_name = shared.model.__class__.__name__
|
||||||
is_llamacpp = (model_class_name == 'LlamaServer')
|
is_llamacpp = (model_class_name == 'LlamaServer')
|
||||||
|
|
||||||
if model_class_name in ['Exllamav3Model', 'Exllamav3HF']:
|
if model_class_name in ['Exllamav3Model', 'Exllamav3HF', 'TensorRTLLMModel']:
|
||||||
shared.model.unload()
|
|
||||||
elif model_class_name in ['Exllamav2Model', 'Exllamav2HF'] and hasattr(shared.model, 'unload'):
|
|
||||||
shared.model.unload()
|
shared.model.unload()
|
||||||
|
elif model_class_name == 'LlamaServer':
|
||||||
|
shared.model.stop()
|
||||||
|
|
||||||
shared.model = shared.tokenizer = None
|
shared.model = shared.tokenizer = None
|
||||||
shared.lora_names = []
|
shared.lora_names = []
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,12 @@
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import subprocess
|
|
||||||
from math import floor
|
from math import floor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from modules import chat, loaders, metadata_gguf, shared, ui
|
from modules import loaders, metadata_gguf, shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.utils import resolve_model_path
|
from modules.utils import resolve_model_path
|
||||||
|
|
||||||
|
|
@ -17,9 +15,6 @@ def get_fallback_settings():
|
||||||
return {
|
return {
|
||||||
'bf16': False,
|
'bf16': False,
|
||||||
'ctx_size': 8192,
|
'ctx_size': 8192,
|
||||||
'rope_freq_base': 0,
|
|
||||||
'compress_pos_emb': 1,
|
|
||||||
'alpha_value': 1,
|
|
||||||
'truncation_length': shared.settings['truncation_length'],
|
'truncation_length': shared.settings['truncation_length'],
|
||||||
'truncation_length_info': shared.settings['truncation_length'],
|
'truncation_length_info': shared.settings['truncation_length'],
|
||||||
'skip_special_tokens': shared.settings['skip_special_tokens'],
|
'skip_special_tokens': shared.settings['skip_special_tokens'],
|
||||||
|
|
@ -69,21 +64,19 @@ def get_model_metadata(model):
|
||||||
|
|
||||||
for k in metadata:
|
for k in metadata:
|
||||||
if k.endswith('.context_length'):
|
if k.endswith('.context_length'):
|
||||||
model_settings['ctx_size'] = min(metadata[k], 8192)
|
model_settings['ctx_size'] = 0
|
||||||
model_settings['truncation_length_info'] = metadata[k]
|
model_settings['truncation_length_info'] = metadata[k]
|
||||||
elif k.endswith('rope.freq_base'):
|
|
||||||
model_settings['rope_freq_base'] = metadata[k]
|
|
||||||
elif k.endswith('rope.scale_linear'):
|
|
||||||
model_settings['compress_pos_emb'] = metadata[k]
|
|
||||||
elif k.endswith('rope.scaling.factor'):
|
|
||||||
model_settings['compress_pos_emb'] = metadata[k]
|
|
||||||
elif k.endswith('.block_count'):
|
elif k.endswith('.block_count'):
|
||||||
model_settings['gpu_layers'] = metadata[k] + 1
|
model_settings['gpu_layers'] = -1
|
||||||
model_settings['max_gpu_layers'] = metadata[k] + 1
|
model_settings['max_gpu_layers'] = metadata[k] + 1
|
||||||
|
|
||||||
if 'tokenizer.chat_template' in metadata:
|
if 'tokenizer.chat_template' in metadata:
|
||||||
template = metadata['tokenizer.chat_template']
|
template = metadata['tokenizer.chat_template']
|
||||||
|
if 'tokenizer.ggml.eos_token_id' in metadata:
|
||||||
eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']]
|
eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']]
|
||||||
|
else:
|
||||||
|
eos_token = ""
|
||||||
|
|
||||||
if 'tokenizer.ggml.bos_token_id' in metadata:
|
if 'tokenizer.ggml.bos_token_id' in metadata:
|
||||||
bos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.bos_token_id']]
|
bos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.bos_token_id']]
|
||||||
else:
|
else:
|
||||||
|
|
@ -117,15 +110,6 @@ def get_model_metadata(model):
|
||||||
model_settings['ctx_size'] = min(value, 8192)
|
model_settings['ctx_size'] = min(value, 8192)
|
||||||
break
|
break
|
||||||
|
|
||||||
if 'rope_theta' in metadata:
|
|
||||||
model_settings['rope_freq_base'] = metadata['rope_theta']
|
|
||||||
elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']:
|
|
||||||
model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta']
|
|
||||||
|
|
||||||
if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
|
|
||||||
if metadata['rope_scaling']['type'] == 'linear':
|
|
||||||
model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
|
|
||||||
|
|
||||||
if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
|
if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
|
||||||
model_settings['bf16'] = True
|
model_settings['bf16'] = True
|
||||||
|
|
||||||
|
|
@ -179,10 +163,6 @@ def get_model_metadata(model):
|
||||||
if 'instruction_template' not in model_settings:
|
if 'instruction_template' not in model_settings:
|
||||||
model_settings['instruction_template'] = 'Alpaca'
|
model_settings['instruction_template'] = 'Alpaca'
|
||||||
|
|
||||||
# Ignore rope_freq_base if set to the default value
|
|
||||||
if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000:
|
|
||||||
model_settings.pop('rope_freq_base')
|
|
||||||
|
|
||||||
# Apply user settings from user_data/models/config-user.yaml
|
# Apply user settings from user_data/models/config-user.yaml
|
||||||
settings = shared.user_config
|
settings = shared.user_config
|
||||||
for pat in settings:
|
for pat in settings:
|
||||||
|
|
@ -196,7 +176,7 @@ def get_model_metadata(model):
|
||||||
|
|
||||||
# Load instruction template if defined by name rather than by value
|
# Load instruction template if defined by name rather than by value
|
||||||
if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
|
if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
|
||||||
model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
|
model_settings['instruction_template_str'] = load_instruction_template(model_settings['instruction_template'])
|
||||||
|
|
||||||
return model_settings
|
return model_settings
|
||||||
|
|
||||||
|
|
@ -213,12 +193,8 @@ def infer_loader(model_name, model_settings, hf_quant_method=None):
|
||||||
loader = 'llama.cpp'
|
loader = 'llama.cpp'
|
||||||
elif hf_quant_method == 'exl3':
|
elif hf_quant_method == 'exl3':
|
||||||
loader = 'ExLlamav3'
|
loader = 'ExLlamav3'
|
||||||
elif hf_quant_method in ['exl2', 'gptq']:
|
|
||||||
loader = 'ExLlamav2_HF'
|
|
||||||
elif re.match(r'.*exl3', model_name.lower()):
|
elif re.match(r'.*exl3', model_name.lower()):
|
||||||
loader = 'ExLlamav3'
|
loader = 'ExLlamav3'
|
||||||
elif re.match(r'.*exl2', model_name.lower()):
|
|
||||||
loader = 'ExLlamav2_HF'
|
|
||||||
else:
|
else:
|
||||||
loader = 'Transformers'
|
loader = 'Transformers'
|
||||||
|
|
||||||
|
|
@ -229,7 +205,7 @@ def update_model_parameters(state, initial=False):
|
||||||
'''
|
'''
|
||||||
UI: update the command-line arguments based on the interface values
|
UI: update the command-line arguments based on the interface values
|
||||||
'''
|
'''
|
||||||
elements = ui.list_model_elements() # the names of the parameters
|
elements = loaders.list_model_elements() # the names of the parameters
|
||||||
|
|
||||||
for i, element in enumerate(elements):
|
for i, element in enumerate(elements):
|
||||||
if element not in state:
|
if element not in state:
|
||||||
|
|
@ -249,10 +225,11 @@ def apply_model_settings_to_state(model, state):
|
||||||
'''
|
'''
|
||||||
UI: update the state variable with the model settings
|
UI: update the state variable with the model settings
|
||||||
'''
|
'''
|
||||||
|
import gradio as gr
|
||||||
model_settings = get_model_metadata(model)
|
model_settings = get_model_metadata(model)
|
||||||
if 'loader' in model_settings:
|
if 'loader' in model_settings:
|
||||||
loader = model_settings.pop('loader')
|
loader = model_settings.pop('loader')
|
||||||
if not ((loader == 'ExLlamav2_HF' and state['loader'] == 'ExLlamav2') or (loader == 'ExLlamav3_HF' and state['loader'] == 'ExLlamav3')):
|
if not (loader == 'ExLlamav3_HF' and state['loader'] == 'ExLlamav3'):
|
||||||
state['loader'] = loader
|
state['loader'] = loader
|
||||||
|
|
||||||
for k in model_settings:
|
for k in model_settings:
|
||||||
|
|
@ -261,16 +238,18 @@ def apply_model_settings_to_state(model, state):
|
||||||
|
|
||||||
# Handle GPU layers and VRAM update for llama.cpp
|
# Handle GPU layers and VRAM update for llama.cpp
|
||||||
if state['loader'] == 'llama.cpp' and 'gpu_layers' in model_settings:
|
if state['loader'] == 'llama.cpp' and 'gpu_layers' in model_settings:
|
||||||
vram_info, gpu_layers_update = update_gpu_layers_and_vram(
|
gpu_layers = model_settings['gpu_layers'] # -1 (auto) by default, or user-saved value
|
||||||
|
max_layers = model_settings.get('max_gpu_layers', 256)
|
||||||
|
state['gpu_layers'] = gr.update(value=gpu_layers, maximum=max_layers)
|
||||||
|
|
||||||
|
vram_info = update_gpu_layers_and_vram(
|
||||||
state['loader'],
|
state['loader'],
|
||||||
model,
|
model,
|
||||||
model_settings['gpu_layers'],
|
gpu_layers,
|
||||||
state['ctx_size'],
|
state['ctx_size'],
|
||||||
state['cache_type'],
|
state['cache_type'],
|
||||||
auto_adjust=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
state['gpu_layers'] = gpu_layers_update
|
|
||||||
state['vram_info'] = vram_info
|
state['vram_info'] = vram_info
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
@ -289,7 +268,7 @@ def save_model_settings(model, state):
|
||||||
if model_regex not in user_config:
|
if model_regex not in user_config:
|
||||||
user_config[model_regex] = {}
|
user_config[model_regex] = {}
|
||||||
|
|
||||||
for k in ui.list_model_elements():
|
for k in loaders.list_model_elements():
|
||||||
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
|
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
|
||||||
user_config[model_regex][k] = state[k]
|
user_config[model_regex][k] = state[k]
|
||||||
|
|
||||||
|
|
@ -408,120 +387,113 @@ def estimate_vram(gguf_file, gpu_layers, ctx_size, cache_type):
|
||||||
return vram
|
return vram
|
||||||
|
|
||||||
|
|
||||||
def get_nvidia_vram(return_free=True):
|
def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type):
|
||||||
"""
|
"""
|
||||||
Calculates VRAM statistics across all NVIDIA GPUs by parsing nvidia-smi output.
|
Compute the estimated VRAM usage for the given GPU layers and return
|
||||||
|
an HTML string for the UI display.
|
||||||
Args:
|
|
||||||
return_free (bool): If True, returns free VRAM. If False, returns total VRAM.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: Either the total free VRAM or total VRAM in MiB summed across all detected NVIDIA GPUs.
|
|
||||||
Returns -1 if nvidia-smi command fails (not found, error, etc.).
|
|
||||||
Returns 0 if nvidia-smi succeeds but no GPU memory info found.
|
|
||||||
"""
|
"""
|
||||||
try:
|
if loader != 'llama.cpp' or model in ["None", None] or not model.endswith(".gguf") or gpu_layers < 0 or ctx_size == 0:
|
||||||
# Execute nvidia-smi command
|
return f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">auto</span></div>"
|
||||||
result = subprocess.run(
|
|
||||||
['nvidia-smi'],
|
vram_usage = estimate_vram(model, gpu_layers, ctx_size, cache_type)
|
||||||
capture_output=True,
|
return f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">{vram_usage:.0f} MiB</span></div>"
|
||||||
text=True,
|
|
||||||
check=False
|
|
||||||
|
def load_instruction_template(template):
|
||||||
|
if template == 'None':
|
||||||
|
return ''
|
||||||
|
|
||||||
|
for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']:
|
||||||
|
if filepath.exists():
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
with open(filepath, 'r', encoding='utf-8') as f:
|
||||||
|
file_contents = f.read()
|
||||||
|
data = yaml.safe_load(file_contents)
|
||||||
|
if 'instruction_template' in data:
|
||||||
|
return data['instruction_template']
|
||||||
|
else:
|
||||||
|
return _jinja_template_from_old_format(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _jinja_template_from_old_format(params, verbose=False):
|
||||||
|
MASTER_TEMPLATE = """
|
||||||
|
{%- set ns = namespace(found=false) -%}
|
||||||
|
{%- for message in messages -%}
|
||||||
|
{%- if message['role'] == 'system' -%}
|
||||||
|
{%- set ns.found = true -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- if not ns.found -%}
|
||||||
|
{{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
|
||||||
|
{%- endif %}
|
||||||
|
{%- for message in messages %}
|
||||||
|
{%- if message['role'] == 'system' -%}
|
||||||
|
{{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{%- if message['role'] == 'user' -%}
|
||||||
|
{{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
|
||||||
|
{%- else -%}
|
||||||
|
{{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
{{-'<|PRE-ASSISTANT-GENERATE|>'-}}
|
||||||
|
{%- endif -%}
|
||||||
|
"""
|
||||||
|
|
||||||
|
if 'context' in params and '<|system-message|>' in params['context']:
|
||||||
|
pre_system = params['context'].split('<|system-message|>')[0]
|
||||||
|
post_system = params['context'].split('<|system-message|>')[1]
|
||||||
|
else:
|
||||||
|
pre_system = ''
|
||||||
|
post_system = ''
|
||||||
|
|
||||||
|
pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
|
||||||
|
post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]
|
||||||
|
|
||||||
|
pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
|
||||||
|
pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
|
||||||
|
post_assistant = params['turn_template'].split('<|bot-message|>')[1]
|
||||||
|
|
||||||
|
def preprocess(string):
|
||||||
|
return string.replace('\n', '\\n').replace('\'', '\\\'')
|
||||||
|
|
||||||
|
pre_system = preprocess(pre_system)
|
||||||
|
post_system = preprocess(post_system)
|
||||||
|
pre_user = preprocess(pre_user)
|
||||||
|
post_user = preprocess(post_user)
|
||||||
|
pre_assistant = preprocess(pre_assistant)
|
||||||
|
post_assistant = preprocess(post_assistant)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
'\n',
|
||||||
|
repr(pre_system) + '\n',
|
||||||
|
repr(post_system) + '\n',
|
||||||
|
repr(pre_user) + '\n',
|
||||||
|
repr(post_user) + '\n',
|
||||||
|
repr(pre_assistant) + '\n',
|
||||||
|
repr(post_assistant) + '\n',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if nvidia-smi returned an error
|
result = MASTER_TEMPLATE
|
||||||
if result.returncode != 0:
|
if 'system_message' in params:
|
||||||
return -1
|
result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message']))
|
||||||
|
|
||||||
# Parse the output for memory usage patterns
|
|
||||||
output = result.stdout
|
|
||||||
|
|
||||||
# Find memory usage like "XXXXMiB / YYYYMiB"
|
|
||||||
# Captures used and total memory for each GPU
|
|
||||||
matches = re.findall(r"(\d+)\s*MiB\s*/\s*(\d+)\s*MiB", output)
|
|
||||||
|
|
||||||
if not matches:
|
|
||||||
# No GPUs found in expected format
|
|
||||||
return 0
|
|
||||||
|
|
||||||
total_vram_mib = 0
|
|
||||||
total_free_vram_mib = 0
|
|
||||||
|
|
||||||
for used_mem_str, total_mem_str in matches:
|
|
||||||
try:
|
|
||||||
used_mib = int(used_mem_str)
|
|
||||||
total_mib = int(total_mem_str)
|
|
||||||
total_vram_mib += total_mib
|
|
||||||
total_free_vram_mib += (total_mib - used_mib)
|
|
||||||
except ValueError:
|
|
||||||
# Skip malformed entries
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Return either free or total VRAM based on the flag
|
|
||||||
return total_free_vram_mib if return_free else total_vram_mib
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
# nvidia-smi not found (likely no NVIDIA drivers installed)
|
|
||||||
return -1
|
|
||||||
except Exception:
|
|
||||||
# Handle any other unexpected exceptions
|
|
||||||
return -1
|
|
||||||
|
|
||||||
|
|
||||||
def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type, auto_adjust=False, for_ui=True):
|
|
||||||
"""
|
|
||||||
Unified function to handle GPU layers and VRAM updates.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
for_ui: If True, returns Gradio updates. If False, returns raw values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- If for_ui=True: (vram_info_update, gpu_layers_update) or just vram_info_update
|
|
||||||
- If for_ui=False: (vram_usage, adjusted_layers) or just vram_usage
|
|
||||||
"""
|
|
||||||
if loader != 'llama.cpp' or model in ["None", None] or not model.endswith(".gguf"):
|
|
||||||
vram_info = "<div id=\"vram-info\"'>Estimated VRAM to load the model:</div>"
|
|
||||||
if for_ui:
|
|
||||||
return (vram_info, gr.update()) if auto_adjust else vram_info
|
|
||||||
else:
|
else:
|
||||||
return (0, gpu_layers) if auto_adjust else 0
|
result = result.replace('<|SYSTEM-MESSAGE|>', '')
|
||||||
|
|
||||||
# Get model settings including user preferences
|
result = result.replace('<|PRE-SYSTEM|>', pre_system)
|
||||||
model_settings = get_model_metadata(model)
|
result = result.replace('<|POST-SYSTEM|>', post_system)
|
||||||
|
result = result.replace('<|PRE-USER|>', pre_user)
|
||||||
|
result = result.replace('<|POST-USER|>', post_user)
|
||||||
|
result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
|
||||||
|
result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
|
||||||
|
result = result.replace('<|POST-ASSISTANT|>', post_assistant)
|
||||||
|
|
||||||
current_layers = gpu_layers
|
result = result.strip()
|
||||||
max_layers = model_settings.get('max_gpu_layers', 256)
|
|
||||||
|
|
||||||
if auto_adjust:
|
return result
|
||||||
# Check if this is a user-saved setting
|
|
||||||
user_config = shared.user_config
|
|
||||||
model_regex = Path(model).name + '$'
|
|
||||||
has_user_setting = model_regex in user_config and 'gpu_layers' in user_config[model_regex]
|
|
||||||
|
|
||||||
if not has_user_setting:
|
|
||||||
# No user setting, auto-adjust from the maximum
|
|
||||||
current_layers = max_layers # Start from max
|
|
||||||
|
|
||||||
# Auto-adjust based on available/total VRAM
|
|
||||||
# If a model is loaded and it's for the UI, use the total VRAM to avoid confusion
|
|
||||||
return_free = False if (for_ui and shared.model_name not in [None, 'None']) else True
|
|
||||||
available_vram = get_nvidia_vram(return_free=return_free)
|
|
||||||
if available_vram > 0:
|
|
||||||
tolerance = 577
|
|
||||||
while current_layers > 0 and estimate_vram(model, current_layers, ctx_size, cache_type) > available_vram - tolerance:
|
|
||||||
current_layers -= 1
|
|
||||||
|
|
||||||
# Calculate VRAM with current layers
|
|
||||||
vram_usage = estimate_vram(model, current_layers, ctx_size, cache_type)
|
|
||||||
|
|
||||||
if for_ui:
|
|
||||||
vram_info = f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">{vram_usage:.0f} MiB</span></div>"
|
|
||||||
if auto_adjust:
|
|
||||||
return vram_info, gr.update(value=current_layers, maximum=max_layers)
|
|
||||||
else:
|
|
||||||
return vram_info
|
|
||||||
else:
|
|
||||||
if auto_adjust:
|
|
||||||
return vram_usage, current_layers
|
|
||||||
else:
|
|
||||||
return vram_usage
|
|
||||||
|
|
|
||||||
28
modules/paths.py
Normal file
28
modules/paths.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_user_data_dir():
|
||||||
|
"""
|
||||||
|
Resolve the user_data directory path. Order of precedence:
|
||||||
|
1. --user-data-dir CLI flag (pre-parsed from sys.argv before argparse)
|
||||||
|
2. In --portable mode, prefer ../user_data if it exists
|
||||||
|
3. Default: 'user_data'
|
||||||
|
"""
|
||||||
|
script_dir = Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
# Check sys.argv for --user-data-dir before argparse runs
|
||||||
|
for i, arg in enumerate(sys.argv):
|
||||||
|
if arg == '--user-data-dir' and i + 1 < len(sys.argv):
|
||||||
|
return Path(sys.argv[i + 1])
|
||||||
|
elif arg.startswith('--user-data-dir='):
|
||||||
|
return Path(arg.split('=', 1)[1])
|
||||||
|
|
||||||
|
# In portable mode, prefer ../user_data if it exists
|
||||||
|
is_portable = '--portable' in sys.argv
|
||||||
|
if is_portable:
|
||||||
|
parent_path = script_dir.parent / 'user_data'
|
||||||
|
if parent_path.exists():
|
||||||
|
return parent_path
|
||||||
|
|
||||||
|
return Path('user_data')
|
||||||
|
|
@ -9,17 +9,17 @@ from modules.loaders import loaders_samplers
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
def default_preset():
|
default_preset_values = {
|
||||||
result = {
|
|
||||||
'temperature': 1,
|
'temperature': 1,
|
||||||
'dynatemp_low': 1,
|
'dynatemp_low': 1,
|
||||||
'dynatemp_high': 1,
|
'dynatemp_high': 1,
|
||||||
'dynatemp_exponent': 1,
|
'dynatemp_exponent': 1,
|
||||||
'smoothing_factor': 0,
|
'smoothing_factor': 0,
|
||||||
'smoothing_curve': 1,
|
'smoothing_curve': 1,
|
||||||
'min_p': 0,
|
|
||||||
'top_p': 1,
|
'top_p': 1,
|
||||||
'top_k': 0,
|
'top_k': 0,
|
||||||
|
'min_p': 0,
|
||||||
|
'top_n_sigma': 0,
|
||||||
'typical_p': 1,
|
'typical_p': 1,
|
||||||
'xtc_threshold': 0.1,
|
'xtc_threshold': 0.1,
|
||||||
'xtc_probability': 0,
|
'xtc_probability': 0,
|
||||||
|
|
@ -27,7 +27,8 @@ def default_preset():
|
||||||
'eta_cutoff': 0,
|
'eta_cutoff': 0,
|
||||||
'tfs': 1,
|
'tfs': 1,
|
||||||
'top_a': 0,
|
'top_a': 0,
|
||||||
'top_n_sigma': 0,
|
'adaptive_target': 0,
|
||||||
|
'adaptive_decay': 0.9,
|
||||||
'dry_multiplier': 0,
|
'dry_multiplier': 0,
|
||||||
'dry_allowed_length': 2,
|
'dry_allowed_length': 2,
|
||||||
'dry_base': 1.75,
|
'dry_base': 1.75,
|
||||||
|
|
@ -45,9 +46,13 @@ def default_preset():
|
||||||
'do_sample': True,
|
'do_sample': True,
|
||||||
'dynamic_temperature': False,
|
'dynamic_temperature': False,
|
||||||
'temperature_last': False,
|
'temperature_last': False,
|
||||||
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
|
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nadaptive_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
|
||||||
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
|
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def default_preset():
|
||||||
|
result = dict(default_preset_values)
|
||||||
|
|
||||||
if shared.args.portable:
|
if shared.args.portable:
|
||||||
samplers = result['sampler_priority'].split('\n')
|
samplers = result['sampler_priority'].split('\n')
|
||||||
|
|
@ -64,7 +69,7 @@ def presets_params():
|
||||||
def load_preset(name, verbose=False):
|
def load_preset(name, verbose=False):
|
||||||
generate_params = default_preset()
|
generate_params = default_preset()
|
||||||
if name not in ['None', None, '']:
|
if name not in ['None', None, '']:
|
||||||
path = Path(f'user_data/presets/{name}.yaml')
|
path = shared.user_data_dir / 'presets' / f'{name}.yaml'
|
||||||
if path.exists():
|
if path.exists():
|
||||||
with open(path, 'r') as infile:
|
with open(path, 'r') as infile:
|
||||||
preset = yaml.safe_load(infile)
|
preset = yaml.safe_load(infile)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ def load_prompt(fname):
|
||||||
if not fname:
|
if not fname:
|
||||||
# Create new file
|
# Create new file
|
||||||
new_name = utils.current_time()
|
new_name = utils.current_time()
|
||||||
prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
initial_content = "In this story,"
|
initial_content = "In this story,"
|
||||||
prompt_path.write_text(initial_content, encoding='utf-8')
|
prompt_path.write_text(initial_content, encoding='utf-8')
|
||||||
|
|
@ -18,7 +18,7 @@ def load_prompt(fname):
|
||||||
|
|
||||||
return initial_content
|
return initial_content
|
||||||
|
|
||||||
file_path = Path(f'user_data/logs/notebook/{fname}.txt')
|
file_path = shared.user_data_dir / 'logs' / 'notebook' / f'{fname}.txt'
|
||||||
if file_path.exists():
|
if file_path.exists():
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
text = f.read()
|
text = f.read()
|
||||||
|
|
@ -33,5 +33,5 @@ def count_tokens(text):
|
||||||
try:
|
try:
|
||||||
tokens = get_encoded_length(text)
|
tokens = get_encoded_length(text)
|
||||||
return str(tokens)
|
return str(tokens)
|
||||||
except:
|
except Exception:
|
||||||
return '0'
|
return '0'
|
||||||
|
|
|
||||||
94
modules/reasoning.py
Normal file
94
modules/reasoning.py
Normal file
|
|
@ -0,0 +1,94 @@
|
||||||
|
import html as html_module
|
||||||
|
|
||||||
|
# Thinking block format definitions: (start_tag, end_tag, content_start_tag)
|
||||||
|
# Use None for start_tag to match from beginning (end-only formats should be listed last)
|
||||||
|
THINKING_FORMATS = [
|
||||||
|
('<think>', '</think>', None),
|
||||||
|
('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'),
|
||||||
|
('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'),
|
||||||
|
('<seed:think>', '</seed:think>', None),
|
||||||
|
('<|think|>', '<|end|>', '<|content|>'), # Solar Open
|
||||||
|
# ('Thinking Process:', '</think>', None), # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming
|
||||||
|
(None, '</think>', None), # End-only variant (e.g., Qwen3-next)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_reasoning(text, html_escaped=False):
|
||||||
|
"""Extract reasoning/thinking blocks from the beginning of a string.
|
||||||
|
|
||||||
|
When html_escaped=True, tags are HTML-escaped before searching
|
||||||
|
(for use on already-escaped UI strings).
|
||||||
|
|
||||||
|
Returns (reasoning_content, final_content) where reasoning_content is
|
||||||
|
None if no thinking block is found.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return None, text
|
||||||
|
|
||||||
|
esc = html_module.escape if html_escaped else lambda s: s
|
||||||
|
|
||||||
|
for start_tag, end_tag, content_tag in THINKING_FORMATS:
|
||||||
|
end_esc = esc(end_tag)
|
||||||
|
content_esc = esc(content_tag) if content_tag else None
|
||||||
|
|
||||||
|
if start_tag is None:
|
||||||
|
# End-only format: require end tag, start from beginning
|
||||||
|
end_pos = text.find(end_esc)
|
||||||
|
if end_pos == -1:
|
||||||
|
continue
|
||||||
|
thought_start = 0
|
||||||
|
else:
|
||||||
|
# Normal format: require start tag
|
||||||
|
start_esc = esc(start_tag)
|
||||||
|
start_pos = text.find(start_esc)
|
||||||
|
if start_pos == -1:
|
||||||
|
# During streaming, the start tag may be arriving partially.
|
||||||
|
# If the text is a prefix of a start tag, return empty content
|
||||||
|
# to prevent the partial tag from leaking.
|
||||||
|
stripped = text.strip()
|
||||||
|
if stripped and start_esc.startswith(stripped):
|
||||||
|
return '', ''
|
||||||
|
continue
|
||||||
|
thought_start = start_pos + len(start_esc)
|
||||||
|
end_pos = text.find(end_esc, thought_start)
|
||||||
|
|
||||||
|
if end_pos == -1:
|
||||||
|
# End tag missing - check if content tag can serve as fallback
|
||||||
|
if content_esc:
|
||||||
|
content_pos = text.find(content_esc, thought_start)
|
||||||
|
if content_pos != -1:
|
||||||
|
thought_end = content_pos
|
||||||
|
content_start = content_pos + len(content_esc)
|
||||||
|
else:
|
||||||
|
thought_end = len(text)
|
||||||
|
content_start = len(text)
|
||||||
|
else:
|
||||||
|
thought_end = len(text)
|
||||||
|
content_start = len(text)
|
||||||
|
else:
|
||||||
|
thought_end = end_pos
|
||||||
|
if content_esc:
|
||||||
|
content_pos = text.find(content_esc, end_pos)
|
||||||
|
if content_pos != -1:
|
||||||
|
content_start = content_pos + len(content_esc)
|
||||||
|
else:
|
||||||
|
# Content tag expected but not yet present (e.g. partial
|
||||||
|
# streaming) — suppress intermediate tags between end_tag
|
||||||
|
# and content_tag so they don't leak as content.
|
||||||
|
content_start = len(text)
|
||||||
|
else:
|
||||||
|
content_start = end_pos + len(end_esc)
|
||||||
|
|
||||||
|
return text[thought_start:thought_end], text[content_start:]
|
||||||
|
|
||||||
|
# Handle standalone GPT-OSS final channel marker without a preceding
|
||||||
|
# analysis/commentary block (the model skipped thinking entirely).
|
||||||
|
for marker in ['<|start|>assistant<|channel|>final<|message|>', '<|channel|>final<|message|>']:
|
||||||
|
marker_esc = esc(marker)
|
||||||
|
pos = text.find(marker_esc)
|
||||||
|
if pos != -1:
|
||||||
|
before = text[:pos].strip()
|
||||||
|
after = text[pos + len(marker_esc):]
|
||||||
|
return (before if before else None), after
|
||||||
|
|
||||||
|
return None, text
|
||||||
|
|
@ -235,6 +235,73 @@ class TopNSigmaLogitsWarper(LogitsProcessor):
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptivePLogitsWarper(LogitsProcessor):
|
||||||
|
'''
|
||||||
|
Adaptive-p sampling. A stateful sampler that favors tokens near a target
|
||||||
|
probability, using an EMA-based control loop to adapt over time.
|
||||||
|
|
||||||
|
Matches the llama.cpp implementation from PR #17927.
|
||||||
|
'''
|
||||||
|
|
||||||
|
DISTRIBUTION_WIDTH = 0.3
|
||||||
|
PEAK_LOGIT_VALUE = 5.0
|
||||||
|
SHARPNESS = 10.0
|
||||||
|
INV_WIDTH = 1.0 / DISTRIBUTION_WIDTH
|
||||||
|
|
||||||
|
def __init__(self, adaptive_target, adaptive_decay, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
||||||
|
self.target = adaptive_target
|
||||||
|
self.decay = min(adaptive_decay, 0.99)
|
||||||
|
self.filter_value = filter_value
|
||||||
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
|
# Initialize EMA at equilibrium (as if target was already achieved)
|
||||||
|
if self.decay < 1.0:
|
||||||
|
self.weighted_sum = self.target / (1.0 - self.decay)
|
||||||
|
self.total_weight = 1.0 / (1.0 - self.decay)
|
||||||
|
else:
|
||||||
|
self.weighted_sum = 0.0
|
||||||
|
self.total_weight = 0.0
|
||||||
|
|
||||||
|
def __call__(self, input_ids, scores):
|
||||||
|
logits = scores[0]
|
||||||
|
|
||||||
|
# Compute original probabilities (before transform)
|
||||||
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
# Compute adapted target using proportional control on the EMA
|
||||||
|
if self.total_weight > 0:
|
||||||
|
ema_avg = self.weighted_sum / self.total_weight
|
||||||
|
else:
|
||||||
|
ema_avg = self.target
|
||||||
|
|
||||||
|
adapted_target = max(0.0, min(1.0, 2.0 * self.target - ema_avg))
|
||||||
|
|
||||||
|
# Adaptive probability transform:
|
||||||
|
# quadratic near target for fine differentiation, transitioning
|
||||||
|
# to linear decay in the tails for proper suppression after softmax
|
||||||
|
dist = torch.abs((probs - adapted_target) * self.INV_WIDTH)
|
||||||
|
new_logits = self.PEAK_LOGIT_VALUE - self.SHARPNESS * dist * dist / (1.0 + dist)
|
||||||
|
|
||||||
|
# Preserve already-masked tokens (-inf logits from prior samplers)
|
||||||
|
new_logits = torch.where(torch.isfinite(logits), new_logits, logits)
|
||||||
|
|
||||||
|
# Softmax and sample from the transformed distribution
|
||||||
|
new_probs = torch.softmax(new_logits, dim=-1)
|
||||||
|
selected = torch.multinomial(new_probs, num_samples=1, replacement=True)
|
||||||
|
|
||||||
|
# Update EMA with the original probability of the selected token
|
||||||
|
original_prob = probs[selected[0]].item()
|
||||||
|
self.weighted_sum = original_prob + self.decay * self.weighted_sum
|
||||||
|
self.total_weight = 1.0 + self.decay * self.total_weight
|
||||||
|
|
||||||
|
# Mask all tokens except the selected one
|
||||||
|
indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool)
|
||||||
|
indices_to_remove[selected[0]] = False
|
||||||
|
indices_to_remove = indices_to_remove.unsqueeze(0)
|
||||||
|
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
# Exclude Top Choices (XTC)
|
# Exclude Top Choices (XTC)
|
||||||
class XTCLogitsWarper(LogitsProcessor):
|
class XTCLogitsWarper(LogitsProcessor):
|
||||||
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
|
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
|
||||||
|
|
@ -575,6 +642,15 @@ def get_logits_processor_patch(self, **kwargs):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if generation_config.adaptive_target is not None and generation_config.adaptive_target > 0.0:
|
||||||
|
warpers_to_add.append(
|
||||||
|
AdaptivePLogitsWarper(
|
||||||
|
adaptive_target=generation_config.adaptive_target,
|
||||||
|
adaptive_decay=generation_config.adaptive_decay,
|
||||||
|
min_tokens_to_keep=min_tokens_to_keep
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
|
if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
|
||||||
warpers_to_add.append(
|
warpers_to_add.append(
|
||||||
XTCLogitsWarper(
|
XTCLogitsWarper(
|
||||||
|
|
@ -640,6 +716,7 @@ def get_logits_processor_patch(self, **kwargs):
|
||||||
'TemperatureLogitsWarperCustom': 'temperature',
|
'TemperatureLogitsWarperCustom': 'temperature',
|
||||||
'TopALogitsWarper': 'top_a',
|
'TopALogitsWarper': 'top_a',
|
||||||
'TopNSigmaLogitsWarper': 'top_n_sigma',
|
'TopNSigmaLogitsWarper': 'top_n_sigma',
|
||||||
|
'AdaptivePLogitsWarper': 'adaptive_p',
|
||||||
'TopKLogitsWarper': 'top_k',
|
'TopKLogitsWarper': 'top_k',
|
||||||
'TopPLogitsWarper': 'top_p',
|
'TopPLogitsWarper': 'top_p',
|
||||||
'TypicalLogitsWarper': 'typical_p',
|
'TypicalLogitsWarper': 'typical_p',
|
||||||
|
|
@ -688,6 +765,8 @@ def generation_config_init_patch(self, **kwargs):
|
||||||
self.tfs = kwargs.pop("tfs", 1.0)
|
self.tfs = kwargs.pop("tfs", 1.0)
|
||||||
self.top_a = kwargs.pop("top_a", 0.0)
|
self.top_a = kwargs.pop("top_a", 0.0)
|
||||||
self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0)
|
self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0)
|
||||||
|
self.adaptive_target = kwargs.pop("adaptive_target", 0.0)
|
||||||
|
self.adaptive_decay = kwargs.pop("adaptive_decay", 0.9)
|
||||||
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||||
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
|
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
|
||||||
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
|
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
|
||||||
|
|
@ -701,7 +780,7 @@ def generation_config_init_patch(self, **kwargs):
|
||||||
self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
|
self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
|
||||||
self.xtc_probability = kwargs.pop("xtc_probability", 0)
|
self.xtc_probability = kwargs.pop("xtc_probability", 0)
|
||||||
self.temperature_last = kwargs.pop("temperature_last", False)
|
self.temperature_last = kwargs.pop("temperature_last", False)
|
||||||
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
|
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'adaptive_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
|
||||||
|
|
||||||
|
|
||||||
def hijack_samplers():
|
def hijack_samplers():
|
||||||
|
|
|
||||||
|
|
@ -58,9 +58,8 @@ class SaneListIndentProcessor(ListIndentProcessor):
|
||||||
def test(self, parent: etree.Element, block: str) -> bool:
|
def test(self, parent: etree.Element, block: str) -> bool:
|
||||||
return block.startswith(' ' * MIN_NESTED_LIST_INDENT) and \
|
return block.startswith(' ' * MIN_NESTED_LIST_INDENT) and \
|
||||||
not self.parser.state.isstate('detabbed') and \
|
not self.parser.state.isstate('detabbed') and \
|
||||||
(parent.tag in self.ITEM_TYPES or
|
(parent.tag in self.ITEM_TYPES or (len(parent) and parent[-1] is not None and (parent[-1].tag in
|
||||||
(len(parent) and parent[-1] is not None and
|
self.LIST_TYPES)))
|
||||||
(parent[-1].tag in self.LIST_TYPES)))
|
|
||||||
|
|
||||||
def get_level(self, parent: etree.Element, block: str) -> tuple[int, etree.Element]:
|
def get_level(self, parent: etree.Element, block: str) -> tuple[int, etree.Element]:
|
||||||
""" Get level of indentation based on list level. """
|
""" Get level of indentation based on list level. """
|
||||||
|
|
@ -79,8 +78,7 @@ class SaneListIndentProcessor(ListIndentProcessor):
|
||||||
# Step through children of tree to find matching indent level.
|
# Step through children of tree to find matching indent level.
|
||||||
while indent_level > level:
|
while indent_level > level:
|
||||||
child = self.lastChild(parent)
|
child = self.lastChild(parent)
|
||||||
if (child is not None and
|
if child is not None and (child.tag in self.LIST_TYPES or child.tag in self.ITEM_TYPES):
|
||||||
(child.tag in self.LIST_TYPES or child.tag in self.ITEM_TYPES)):
|
|
||||||
if child.tag in self.LIST_TYPES:
|
if child.tag in self.LIST_TYPES:
|
||||||
level += 1
|
level += 1
|
||||||
parent = child
|
parent = child
|
||||||
|
|
@ -124,16 +122,14 @@ class SaneOListProcessor(OListProcessor):
|
||||||
|
|
||||||
def __init__(self, parser: blockparser.BlockParser):
|
def __init__(self, parser: blockparser.BlockParser):
|
||||||
super().__init__(parser)
|
super().__init__(parser)
|
||||||
# This restriction stems from the 'CodeBlockProcessor' class,
|
max_list_start_indent = self.tab_length
|
||||||
# which automatically matches blocks with an indent = self.tab_length
|
|
||||||
max_list_start_indent = self.tab_length - 1
|
|
||||||
# Detect an item (e.g., `1. item`)
|
# Detect an item (e.g., `1. item`)
|
||||||
self.RE = re.compile(r'^[ ]{0,%d}[\*_]{0,2}\d+\.[ ]+(.*)' % max_list_start_indent)
|
self.RE = re.compile(r'^[ ]{0,%d}[\*_]{0,2}\d+\.[ ]+(.*)' % max_list_start_indent)
|
||||||
# Detect items on secondary lines. they can be of either list type.
|
# Detect items on secondary lines. they can be of either list type.
|
||||||
self.CHILD_RE = re.compile(r'^[ ]{0,%d}([\*_]{0,2})((\d+\.))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1))
|
self.CHILD_RE = re.compile(r'^[ ]{0,%d}([\*_]{0,2})((\d+\.))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1))
|
||||||
# Detect indented (nested) items of either type
|
# Detect indented (nested) items of either type
|
||||||
self.INDENT_RE = re.compile(r'^[ ]{%d,%d}[\*_]{0,2}((\d+\.)|[*+-])[ ]+.*' %
|
self.INDENT_RE = re.compile(r'^[ ]{%d,%d}[\*_]{0,2}((\d+\.)|[*+-])[ ]+.*' %
|
||||||
(MIN_NESTED_LIST_INDENT, self.tab_length * 2 - 1))
|
(MIN_NESTED_LIST_INDENT, self.tab_length * 2))
|
||||||
|
|
||||||
def run(self, parent: etree.Element, blocks: list[str]) -> None:
|
def run(self, parent: etree.Element, blocks: list[str]) -> None:
|
||||||
# Check for multiple items in one block.
|
# Check for multiple items in one block.
|
||||||
|
|
@ -242,7 +238,7 @@ class SaneUListProcessor(SaneOListProcessor):
|
||||||
def __init__(self, parser: blockparser.BlockParser):
|
def __init__(self, parser: blockparser.BlockParser):
|
||||||
super().__init__(parser)
|
super().__init__(parser)
|
||||||
# Detect an item (e.g., `- item` or `+ item` or `* item`).
|
# Detect an item (e.g., `- item` or `+ item` or `* item`).
|
||||||
max_list_start_indent = self.tab_length - 1
|
max_list_start_indent = self.tab_length
|
||||||
self.RE = re.compile(r'^[ ]{0,%d}[*+-][ ]+(.*)' % max_list_start_indent)
|
self.RE = re.compile(r'^[ ]{0,%d}[*+-][ ]+(.*)' % max_list_start_indent)
|
||||||
self.CHILD_RE = re.compile(r'^[ ]{0,%d}(([*+-]))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1))
|
self.CHILD_RE = re.compile(r'^[ ]{0,%d}(([*+-]))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1))
|
||||||
|
|
||||||
|
|
@ -275,7 +271,7 @@ class SaneParagraphProcessor(ParagraphProcessor):
|
||||||
|
|
||||||
def __init__(self, parser: BlockParser):
|
def __init__(self, parser: BlockParser):
|
||||||
super().__init__(parser)
|
super().__init__(parser)
|
||||||
max_list_start_indent = self.tab_length - 1
|
max_list_start_indent = self.tab_length
|
||||||
self.LIST_RE = re.compile(r"\s{2}\n(\s{0,%d}[\d+*-])" % max_list_start_indent)
|
self.LIST_RE = re.compile(r"\s{2}\n(\s{0,%d}[\d+*-])" % max_list_start_indent)
|
||||||
|
|
||||||
def run(self, parent: etree.Element, blocks: list[str]) -> None:
|
def run(self, parent: etree.Element, blocks: list[str]) -> None:
|
||||||
|
|
@ -331,6 +327,9 @@ class SaneListExtension(Extension):
|
||||||
md.parser.blockprocessors.register(SaneUListProcessor(md.parser), 'ulist', 30)
|
md.parser.blockprocessors.register(SaneUListProcessor(md.parser), 'ulist', 30)
|
||||||
md.parser.blockprocessors.register(SaneParagraphProcessor(md.parser), 'paragraph', 10)
|
md.parser.blockprocessors.register(SaneParagraphProcessor(md.parser), 'paragraph', 10)
|
||||||
|
|
||||||
|
# Disable uncommon indented codeblocks (as opposed to fenced codeblocks delimited by "```")
|
||||||
|
md.parser.blockprocessors.deregister('code')
|
||||||
|
|
||||||
|
|
||||||
def makeExtension(**kwargs): # pragma: no cover
|
def makeExtension(**kwargs): # pragma: no cover
|
||||||
return SaneListExtension(**kwargs)
|
return SaneListExtension(**kwargs)
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,11 @@ from pathlib import Path
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.presets import default_preset
|
from modules.paths import resolve_user_data_dir
|
||||||
|
from modules.presets import default_preset, default_preset_values
|
||||||
|
|
||||||
|
# Resolve user_data directory early (before argparse defaults are set)
|
||||||
|
user_data_dir = resolve_user_data_dir()
|
||||||
|
|
||||||
# Text model variables
|
# Text model variables
|
||||||
model = None
|
model = None
|
||||||
|
|
@ -42,11 +46,12 @@ parser = argparse.ArgumentParser(description="Text Generation Web UI", conflict_
|
||||||
|
|
||||||
# Basic settings
|
# Basic settings
|
||||||
group = parser.add_argument_group('Basic settings')
|
group = parser.add_argument_group('Basic settings')
|
||||||
group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.')
|
group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user data directory. Default: auto-detected.')
|
||||||
|
group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.')
|
||||||
group.add_argument('--model', type=str, help='Name of the model to load by default.')
|
group.add_argument('--model', type=str, help='Name of the model to load by default.')
|
||||||
group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
|
group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
|
||||||
group.add_argument('--model-dir', type=str, default='user_data/models', help='Path to directory with all the models.')
|
group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.')
|
||||||
group.add_argument('--lora-dir', type=str, default='user_data/loras', help='Path to directory with all the loras.')
|
group.add_argument('--lora-dir', type=str, default=str(user_data_dir / 'loras'), help='Path to directory with all the loras.')
|
||||||
group.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.')
|
group.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.')
|
||||||
group.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See user_data/settings-template.yaml for an example. If you create a file called user_data/settings.yaml, this file will be loaded by default without the need to use the --settings flag.')
|
group.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See user_data/settings-template.yaml for an example. If you create a file called user_data/settings.yaml, this file will be loaded by default without the need to use the --settings flag.')
|
||||||
group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
|
group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
|
||||||
|
|
@ -56,7 +61,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft
|
||||||
# Image generation
|
# Image generation
|
||||||
group = parser.add_argument_group('Image model')
|
group = parser.add_argument_group('Image model')
|
||||||
group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).')
|
group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).')
|
||||||
group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
|
group.add_argument('--image-model-dir', type=str, default=str(user_data_dir / 'image_models'), help='Path to directory with all the image models.')
|
||||||
group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.')
|
group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.')
|
||||||
group.add_argument('--image-attn-backend', type=str, default=None, choices=['flash_attention_2', 'sdpa'], help='Attention backend for image model.')
|
group.add_argument('--image-attn-backend', type=str, default=None, choices=['flash_attention_2', 'sdpa'], help='Attention backend for image model.')
|
||||||
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
|
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
|
||||||
|
|
@ -67,12 +72,12 @@ group.add_argument('--image-quant', type=str, default=None,
|
||||||
|
|
||||||
# Model loader
|
# Model loader
|
||||||
group = parser.add_argument_group('Model loader')
|
group = parser.add_argument_group('Model loader')
|
||||||
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, TensorRT-LLM.')
|
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT-LLM.')
|
||||||
|
|
||||||
# Cache
|
# Cache
|
||||||
group = parser.add_argument_group('Context and cache')
|
group = parser.add_argument_group('Context and cache')
|
||||||
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, metavar='N', help='Context size in tokens.')
|
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=0, metavar='N', help='Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.')
|
||||||
group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
|
group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
group = parser.add_argument_group('Speculative decoding')
|
group = parser.add_argument_group('Speculative decoding')
|
||||||
|
|
@ -81,10 +86,14 @@ group.add_argument('--draft-max', type=int, default=4, help='Number of tokens to
|
||||||
group.add_argument('--gpu-layers-draft', type=int, default=256, help='Number of layers to offload to the GPU for the draft model.')
|
group.add_argument('--gpu-layers-draft', type=int, default=256, help='Number of layers to offload to the GPU for the draft model.')
|
||||||
group.add_argument('--device-draft', type=str, default=None, help='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
|
group.add_argument('--device-draft', type=str, default=None, help='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
|
||||||
group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
|
group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
|
||||||
|
group.add_argument('--spec-type', type=str, default='none', choices=['none', 'ngram-mod', 'ngram-simple', 'ngram-map-k', 'ngram-map-k4v', 'ngram-cache'], help='Draftless speculative decoding type. Recommended: ngram-mod.')
|
||||||
|
group.add_argument('--spec-ngram-size-n', type=int, default=24, help='N-gram lookup size for ngram speculative decoding.')
|
||||||
|
group.add_argument('--spec-ngram-size-m', type=int, default=48, help='Draft n-gram size for ngram speculative decoding.')
|
||||||
|
group.add_argument('--spec-ngram-min-hits', type=int, default=1, help='Minimum n-gram hits for ngram-map speculative decoding.')
|
||||||
|
|
||||||
# llama.cpp
|
# llama.cpp
|
||||||
group = parser.add_argument_group('llama.cpp')
|
group = parser.add_argument_group('llama.cpp')
|
||||||
group.add_argument('--gpu-layers', '--n-gpu-layers', type=int, default=256, metavar='N', help='Number of layers to offload to the GPU.')
|
group.add_argument('--gpu-layers', '--n-gpu-layers', type=int, default=-1, metavar='N', help='Number of layers to offload to the GPU. -1 = auto.')
|
||||||
group.add_argument('--cpu-moe', action='store_true', help='Move the experts to the CPU (for MoE models).')
|
group.add_argument('--cpu-moe', action='store_true', help='Move the experts to the CPU (for MoE models).')
|
||||||
group.add_argument('--mmproj', type=str, default=None, help='Path to the mmproj file for vision models.')
|
group.add_argument('--mmproj', type=str, default=None, help='Path to the mmproj file for vision models.')
|
||||||
group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
|
group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
|
||||||
|
|
@ -98,6 +107,8 @@ group.add_argument('--ubatch-size', type=int, default=1024, help='Maximum number
|
||||||
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
|
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
|
||||||
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
|
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
|
||||||
group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
|
group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
|
||||||
|
group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
|
||||||
|
group.add_argument('--fit-target', type=str, default='512', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.')
|
||||||
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
|
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
|
||||||
|
|
||||||
# Transformers/Accelerate
|
# Transformers/Accelerate
|
||||||
|
|
@ -105,7 +116,7 @@ group = parser.add_argument_group('Transformers/Accelerate')
|
||||||
group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
||||||
group.add_argument('--cpu-memory', type=float, default=0, help='Maximum CPU memory in GiB. Use this for CPU offloading.')
|
group.add_argument('--cpu-memory', type=float, default=0, help='Maximum CPU memory in GiB. Use this for CPU offloading.')
|
||||||
group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
|
group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
|
||||||
group.add_argument('--disk-cache-dir', type=str, default='user_data/cache', help='Directory to save the disk cache to. Defaults to "user_data/cache".')
|
group.add_argument('--disk-cache-dir', type=str, default=str(user_data_dir / 'cache'), help='Directory to save the disk cache to.')
|
||||||
group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).')
|
group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).')
|
||||||
group.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
group.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||||
group.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.')
|
group.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.')
|
||||||
|
|
@ -123,34 +134,10 @@ group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for
|
||||||
|
|
||||||
# ExLlamaV3
|
# ExLlamaV3
|
||||||
group = parser.add_argument_group('ExLlamaV3')
|
group = parser.add_argument_group('ExLlamaV3')
|
||||||
|
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
|
||||||
group.add_argument('--enable-tp', '--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) to split the model across GPUs.')
|
group.add_argument('--enable-tp', '--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) to split the model across GPUs.')
|
||||||
group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.')
|
group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.')
|
||||||
|
group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
|
||||||
# ExLlamaV2
|
|
||||||
group = parser.add_argument_group('ExLlamaV2')
|
|
||||||
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
|
|
||||||
group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.')
|
|
||||||
group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
|
|
||||||
group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
|
|
||||||
group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
|
|
||||||
group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.')
|
|
||||||
group.add_argument('--num_experts_per_token', type=int, default=2, metavar='N', help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
|
|
||||||
|
|
||||||
# TensorRT-LLM
|
|
||||||
group = parser.add_argument_group('TensorRT-LLM')
|
|
||||||
group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn\'t support streaming yet.')
|
|
||||||
|
|
||||||
# DeepSpeed
|
|
||||||
group = parser.add_argument_group('DeepSpeed')
|
|
||||||
group.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
|
|
||||||
group.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
|
|
||||||
group.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
|
|
||||||
|
|
||||||
# RoPE
|
|
||||||
group = parser.add_argument_group('RoPE')
|
|
||||||
group.add_argument('--alpha_value', type=float, default=1, help='Positional embeddings alpha factor for NTK RoPE scaling. Use either this or compress_pos_emb, not both.')
|
|
||||||
group.add_argument('--rope_freq_base', type=int, default=0, help='If greater than 0, will be used instead of alpha_value. Those two are related by rope_freq_base = 10000 * alpha_value ^ (64 / 63).')
|
|
||||||
group.add_argument('--compress_pos_emb', type=int, default=1, help="Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.")
|
|
||||||
|
|
||||||
# Gradio
|
# Gradio
|
||||||
group = parser.add_argument_group('Gradio')
|
group = parser.add_argument_group('Gradio')
|
||||||
|
|
@ -170,7 +157,7 @@ group.add_argument('--portable', action='store_true', help='Hide features not av
|
||||||
# API
|
# API
|
||||||
group = parser.add_argument_group('API')
|
group = parser.add_argument_group('API')
|
||||||
group.add_argument('--api', action='store_true', help='Enable the API extension.')
|
group.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||||
group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')
|
||||||
group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
|
group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
|
||||||
group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
|
group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
|
||||||
group.add_argument('--api-key', type=str, default='', help='API authentication key.')
|
group.add_argument('--api-key', type=str, default='', help='API authentication key.')
|
||||||
|
|
@ -179,8 +166,53 @@ group.add_argument('--api-enable-ipv6', action='store_true', help='Enable IPv6 f
|
||||||
group.add_argument('--api-disable-ipv4', action='store_true', help='Disable IPv4 for the API')
|
group.add_argument('--api-disable-ipv4', action='store_true', help='Disable IPv4 for the API')
|
||||||
group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.')
|
group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.')
|
||||||
|
|
||||||
|
# API generation defaults
|
||||||
|
_d = default_preset_values
|
||||||
|
group = parser.add_argument_group('API generation defaults')
|
||||||
|
group.add_argument('--temperature', type=float, default=_d['temperature'], metavar='N', help='Temperature')
|
||||||
|
group.add_argument('--dynatemp-low', type=float, default=_d['dynatemp_low'], metavar='N', help='Dynamic temperature low')
|
||||||
|
group.add_argument('--dynatemp-high', type=float, default=_d['dynatemp_high'], metavar='N', help='Dynamic temperature high')
|
||||||
|
group.add_argument('--dynatemp-exponent', type=float, default=_d['dynatemp_exponent'], metavar='N', help='Dynamic temperature exponent')
|
||||||
|
group.add_argument('--smoothing-factor', type=float, default=_d['smoothing_factor'], metavar='N', help='Smoothing factor')
|
||||||
|
group.add_argument('--smoothing-curve', type=float, default=_d['smoothing_curve'], metavar='N', help='Smoothing curve')
|
||||||
|
group.add_argument('--top-p', type=float, default=_d['top_p'], metavar='N', help='Top P')
|
||||||
|
group.add_argument('--top-k', type=int, default=_d['top_k'], metavar='N', help='Top K')
|
||||||
|
group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P')
|
||||||
|
group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma')
|
||||||
|
group.add_argument('--typical-p', type=float, default=_d['typical_p'], metavar='N', help='Typical P')
|
||||||
|
group.add_argument('--xtc-threshold', type=float, default=_d['xtc_threshold'], metavar='N', help='XTC threshold')
|
||||||
|
group.add_argument('--xtc-probability', type=float, default=_d['xtc_probability'], metavar='N', help='XTC probability')
|
||||||
|
group.add_argument('--epsilon-cutoff', type=float, default=_d['epsilon_cutoff'], metavar='N', help='Epsilon cutoff')
|
||||||
|
group.add_argument('--eta-cutoff', type=float, default=_d['eta_cutoff'], metavar='N', help='Eta cutoff')
|
||||||
|
group.add_argument('--tfs', type=float, default=_d['tfs'], metavar='N', help='TFS')
|
||||||
|
group.add_argument('--top-a', type=float, default=_d['top_a'], metavar='N', help='Top A')
|
||||||
|
group.add_argument('--adaptive-target', type=float, default=_d['adaptive_target'], metavar='N', help='Adaptive target')
|
||||||
|
group.add_argument('--adaptive-decay', type=float, default=_d['adaptive_decay'], metavar='N', help='Adaptive decay')
|
||||||
|
group.add_argument('--dry-multiplier', type=float, default=_d['dry_multiplier'], metavar='N', help='DRY multiplier')
|
||||||
|
group.add_argument('--dry-allowed-length', type=int, default=_d['dry_allowed_length'], metavar='N', help='DRY allowed length')
|
||||||
|
group.add_argument('--dry-base', type=float, default=_d['dry_base'], metavar='N', help='DRY base')
|
||||||
|
group.add_argument('--repetition-penalty', type=float, default=_d['repetition_penalty'], metavar='N', help='Repetition penalty')
|
||||||
|
group.add_argument('--frequency-penalty', type=float, default=_d['frequency_penalty'], metavar='N', help='Frequency penalty')
|
||||||
|
group.add_argument('--presence-penalty', type=float, default=_d['presence_penalty'], metavar='N', help='Presence penalty')
|
||||||
|
group.add_argument('--encoder-repetition-penalty', type=float, default=_d['encoder_repetition_penalty'], metavar='N', help='Encoder repetition penalty')
|
||||||
|
group.add_argument('--no-repeat-ngram-size', type=int, default=_d['no_repeat_ngram_size'], metavar='N', help='No repeat ngram size')
|
||||||
|
group.add_argument('--repetition-penalty-range', type=int, default=_d['repetition_penalty_range'], metavar='N', help='Repetition penalty range')
|
||||||
|
group.add_argument('--penalty-alpha', type=float, default=_d['penalty_alpha'], metavar='N', help='Penalty alpha')
|
||||||
|
group.add_argument('--guidance-scale', type=float, default=_d['guidance_scale'], metavar='N', help='Guidance scale')
|
||||||
|
group.add_argument('--mirostat-mode', type=int, default=_d['mirostat_mode'], metavar='N', help='Mirostat mode')
|
||||||
|
group.add_argument('--mirostat-tau', type=float, default=_d['mirostat_tau'], metavar='N', help='Mirostat tau')
|
||||||
|
group.add_argument('--mirostat-eta', type=float, default=_d['mirostat_eta'], metavar='N', help='Mirostat eta')
|
||||||
|
group.add_argument('--do-sample', action=argparse.BooleanOptionalAction, default=_d['do_sample'], help='Do sample')
|
||||||
|
group.add_argument('--dynamic-temperature', action=argparse.BooleanOptionalAction, default=_d['dynamic_temperature'], help='Dynamic temperature')
|
||||||
|
group.add_argument('--temperature-last', action=argparse.BooleanOptionalAction, default=_d['temperature_last'], help='Temperature last')
|
||||||
|
group.add_argument('--sampler-priority', type=str, default=_d['sampler_priority'], metavar='N', help='Sampler priority')
|
||||||
|
group.add_argument('--dry-sequence-breakers', type=str, default=_d['dry_sequence_breakers'], metavar='N', help='DRY sequence breakers')
|
||||||
|
group.add_argument('--enable-thinking', action=argparse.BooleanOptionalAction, default=True, help='Enable thinking')
|
||||||
|
group.add_argument('--reasoning-effort', type=str, default='medium', metavar='N', help='Reasoning effort')
|
||||||
|
group.add_argument('--chat-template-file', type=str, default=None, help='Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model\'s built-in template.')
|
||||||
|
|
||||||
# Handle CMD_FLAGS.txt
|
# Handle CMD_FLAGS.txt
|
||||||
cmd_flags_path = Path(__file__).parent.parent / "user_data" / "CMD_FLAGS.txt"
|
cmd_flags_path = user_data_dir / "CMD_FLAGS.txt"
|
||||||
if cmd_flags_path.exists():
|
if cmd_flags_path.exists():
|
||||||
with cmd_flags_path.open('r', encoding='utf-8') as f:
|
with cmd_flags_path.open('r', encoding='utf-8') as f:
|
||||||
cmd_flags = ' '.join(
|
cmd_flags = ' '.join(
|
||||||
|
|
@ -195,6 +227,7 @@ if cmd_flags_path.exists():
|
||||||
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
user_data_dir = Path(args.user_data_dir) # Update from parsed args (may differ from pre-parse)
|
||||||
original_args = copy.deepcopy(args)
|
original_args = copy.deepcopy(args)
|
||||||
args_defaults = parser.parse_args([])
|
args_defaults = parser.parse_args([])
|
||||||
|
|
||||||
|
|
@ -224,8 +257,9 @@ settings = {
|
||||||
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>',
|
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>',
|
||||||
'enable_web_search': False,
|
'enable_web_search': False,
|
||||||
'web_search_pages': 3,
|
'web_search_pages': 3,
|
||||||
|
'selected_tools': [],
|
||||||
'prompt-notebook': '',
|
'prompt-notebook': '',
|
||||||
'preset': 'Qwen3 - Thinking' if Path('user_data/presets/Qwen3 - Thinking.yaml').exists() else None,
|
'preset': 'Top-P' if (user_data_dir / 'presets/Top-P.yaml').exists() else None,
|
||||||
'max_new_tokens': 512,
|
'max_new_tokens': 512,
|
||||||
'max_new_tokens_min': 1,
|
'max_new_tokens_min': 1,
|
||||||
'max_new_tokens_max': 4096,
|
'max_new_tokens_max': 4096,
|
||||||
|
|
@ -250,7 +284,7 @@ settings = {
|
||||||
'include_past_attachments': True,
|
'include_past_attachments': True,
|
||||||
|
|
||||||
# Generation parameters - Curve shape
|
# Generation parameters - Curve shape
|
||||||
'temperature': 0.6,
|
'temperature': neutral_samplers['temperature'],
|
||||||
'dynatemp_low': neutral_samplers['dynatemp_low'],
|
'dynatemp_low': neutral_samplers['dynatemp_low'],
|
||||||
'dynatemp_high': neutral_samplers['dynatemp_high'],
|
'dynatemp_high': neutral_samplers['dynatemp_high'],
|
||||||
'dynatemp_exponent': neutral_samplers['dynatemp_exponent'],
|
'dynatemp_exponent': neutral_samplers['dynatemp_exponent'],
|
||||||
|
|
@ -258,9 +292,10 @@ settings = {
|
||||||
'smoothing_curve': neutral_samplers['smoothing_curve'],
|
'smoothing_curve': neutral_samplers['smoothing_curve'],
|
||||||
|
|
||||||
# Generation parameters - Curve cutoff
|
# Generation parameters - Curve cutoff
|
||||||
'min_p': neutral_samplers['min_p'],
|
|
||||||
'top_p': 0.95,
|
'top_p': 0.95,
|
||||||
'top_k': 20,
|
'top_k': neutral_samplers['top_k'],
|
||||||
|
'min_p': neutral_samplers['min_p'],
|
||||||
|
'top_n_sigma': neutral_samplers['top_n_sigma'],
|
||||||
'typical_p': neutral_samplers['typical_p'],
|
'typical_p': neutral_samplers['typical_p'],
|
||||||
'xtc_threshold': neutral_samplers['xtc_threshold'],
|
'xtc_threshold': neutral_samplers['xtc_threshold'],
|
||||||
'xtc_probability': neutral_samplers['xtc_probability'],
|
'xtc_probability': neutral_samplers['xtc_probability'],
|
||||||
|
|
@ -268,7 +303,8 @@ settings = {
|
||||||
'eta_cutoff': neutral_samplers['eta_cutoff'],
|
'eta_cutoff': neutral_samplers['eta_cutoff'],
|
||||||
'tfs': neutral_samplers['tfs'],
|
'tfs': neutral_samplers['tfs'],
|
||||||
'top_a': neutral_samplers['top_a'],
|
'top_a': neutral_samplers['top_a'],
|
||||||
'top_n_sigma': neutral_samplers['top_n_sigma'],
|
'adaptive_target': neutral_samplers['adaptive_target'],
|
||||||
|
'adaptive_decay': neutral_samplers['adaptive_decay'],
|
||||||
|
|
||||||
# Generation parameters - Repetition suppression
|
# Generation parameters - Repetition suppression
|
||||||
'dry_multiplier': neutral_samplers['dry_multiplier'],
|
'dry_multiplier': neutral_samplers['dry_multiplier'],
|
||||||
|
|
@ -298,6 +334,7 @@ settings = {
|
||||||
|
|
||||||
# Character settings
|
# Character settings
|
||||||
'character': 'Assistant',
|
'character': 'Assistant',
|
||||||
|
'user': 'Default',
|
||||||
'name1': 'You',
|
'name1': 'You',
|
||||||
'name2': 'AI',
|
'name2': 'AI',
|
||||||
'user_bio': '',
|
'user_bio': '',
|
||||||
|
|
@ -305,7 +342,7 @@ settings = {
|
||||||
'greeting': 'How can I help you today?',
|
'greeting': 'How can I help you today?',
|
||||||
'custom_system_message': '',
|
'custom_system_message': '',
|
||||||
'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}",
|
'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}",
|
||||||
'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}",
|
'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {{- '[Tool result: ' + message['content'] + ']\\n' -}}\n {%- elif message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- elif message['tool_calls'] is defined and message['tool_calls'] -%}\n {%- for tc in message['tool_calls'] -%}\n {{- '[Calling: ' + tc['function']['name'] + '(' + tc['function']['arguments'] + ')]\\n' -}}\n {%- endfor -%}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}",
|
||||||
|
|
||||||
# Extensions
|
# Extensions
|
||||||
'default_extensions': [],
|
'default_extensions': [],
|
||||||
|
|
@ -335,6 +372,11 @@ default_settings = copy.deepcopy(settings)
|
||||||
|
|
||||||
|
|
||||||
def do_cmd_flags_warnings():
|
def do_cmd_flags_warnings():
|
||||||
|
# Validate --chat-template-file
|
||||||
|
if args.chat_template_file and not Path(args.chat_template_file).is_file():
|
||||||
|
logger.error(f"--chat-template-file: file not found: {args.chat_template_file}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
# Security warnings
|
# Security warnings
|
||||||
if args.trust_remote_code:
|
if args.trust_remote_code:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -348,9 +390,16 @@ def do_cmd_flags_warnings():
|
||||||
if args.share:
|
if args.share:
|
||||||
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
|
||||||
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
|
||||||
logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
|
logger.warning("You are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
|
||||||
if args.multi_user:
|
if args.multi_user:
|
||||||
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
|
logger.warning(
|
||||||
|
'Multi-user mode is enabled. Known limitations:'
|
||||||
|
'\n- The Stop button stops generation for all users, not just you.'
|
||||||
|
'\n- Chat history is not saved and will be lost on page refresh.'
|
||||||
|
'\n- Only one user can generate at a time unless using a parallel-capable backend (e.g. llama.cpp with --parallel N for N > 1, or ExLlamaV3).'
|
||||||
|
'\n\nThis mode works best for small trusted teams.'
|
||||||
|
'\n\nDo not expose publicly. Grayed-out actions can easily be bypassed client-side.\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_image_model_cli_overrides():
|
def apply_image_model_cli_overrides():
|
||||||
|
|
@ -378,10 +427,6 @@ def fix_loader_name(name):
|
||||||
return 'llama.cpp'
|
return 'llama.cpp'
|
||||||
elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']:
|
elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']:
|
||||||
return 'Transformers'
|
return 'Transformers'
|
||||||
elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']:
|
|
||||||
return 'ExLlamav2'
|
|
||||||
elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']:
|
|
||||||
return 'ExLlamav2_HF'
|
|
||||||
elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']:
|
elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']:
|
||||||
return 'ExLlamav3_HF'
|
return 'ExLlamav3_HF'
|
||||||
elif name in ['exllamav3']:
|
elif name in ['exllamav3']:
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,10 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import tensorrt_llm
|
from tensorrt_llm._tensorrt_engine import LLM
|
||||||
import torch
|
from tensorrt_llm.llmapi import SamplingParams
|
||||||
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
|
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.text_generation import (
|
|
||||||
get_max_prompt_length,
|
|
||||||
get_reply_from_output_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TensorRTLLMModel:
|
class TensorRTLLMModel:
|
||||||
|
|
@ -17,110 +12,50 @@ class TensorRTLLMModel:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(self, path_to_model):
|
def from_pretrained(cls, path_to_model):
|
||||||
|
|
||||||
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
||||||
runtime_rank = tensorrt_llm.mpi_rank()
|
|
||||||
|
|
||||||
# Define model settings
|
llm = LLM(
|
||||||
runner_kwargs = dict(
|
model=str(path_to_model),
|
||||||
engine_dir=str(path_to_model),
|
skip_tokenizer_init=False,
|
||||||
lora_dir=None,
|
|
||||||
rank=runtime_rank,
|
|
||||||
debug_mode=False,
|
|
||||||
lora_ckpt_source="hf",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if shared.args.cpp_runner:
|
result = cls()
|
||||||
logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
|
result.llm = llm
|
||||||
runner_kwargs.update(
|
result.tokenizer = llm.tokenizer
|
||||||
max_batch_size=1,
|
|
||||||
max_input_len=shared.args.ctx_size - 512,
|
|
||||||
max_output_len=512,
|
|
||||||
max_beam_width=1,
|
|
||||||
max_attention_window_size=None,
|
|
||||||
sink_token_length=None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("TensorRT-LLM: Using \"ModelRunner\"")
|
|
||||||
|
|
||||||
# Load the model
|
|
||||||
runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
|
|
||||||
runner = runner_cls.from_dir(**runner_kwargs)
|
|
||||||
|
|
||||||
result = self()
|
|
||||||
result.model = runner
|
|
||||||
result.runtime_rank = runtime_rank
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def generate_with_streaming(self, prompt, state):
|
def generate_with_streaming(self, prompt, state):
|
||||||
batch_input_ids = []
|
sampling_params = SamplingParams(
|
||||||
input_ids = shared.tokenizer.encode(
|
max_tokens=state['max_new_tokens'] if not state['auto_max_new_tokens']
|
||||||
prompt,
|
else state['truncation_length'] - len(shared.tokenizer.encode(prompt)),
|
||||||
add_special_tokens=True,
|
end_id=shared.tokenizer.eos_token_id,
|
||||||
truncation=False,
|
|
||||||
)
|
|
||||||
input_ids = torch.tensor(input_ids, dtype=torch.int32)
|
|
||||||
input_ids = input_ids[-get_max_prompt_length(state):] # Apply truncation_length
|
|
||||||
batch_input_ids.append(input_ids)
|
|
||||||
|
|
||||||
if shared.args.cpp_runner:
|
|
||||||
max_new_tokens = min(512, state['max_new_tokens'])
|
|
||||||
elif state['auto_max_new_tokens']:
|
|
||||||
max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
|
|
||||||
else:
|
|
||||||
max_new_tokens = state['max_new_tokens']
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
generator = self.model.generate(
|
|
||||||
batch_input_ids,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
max_attention_window_size=None,
|
|
||||||
sink_token_length=None,
|
|
||||||
end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
|
|
||||||
pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
|
|
||||||
temperature=state['temperature'],
|
temperature=state['temperature'],
|
||||||
top_k=state['top_k'],
|
top_k=state['top_k'],
|
||||||
top_p=state['top_p'],
|
top_p=state['top_p'],
|
||||||
num_beams=1,
|
min_p=state['min_p'],
|
||||||
length_penalty=1.0,
|
|
||||||
repetition_penalty=state['repetition_penalty'],
|
repetition_penalty=state['repetition_penalty'],
|
||||||
presence_penalty=state['presence_penalty'],
|
presence_penalty=state['presence_penalty'],
|
||||||
frequency_penalty=state['frequency_penalty'],
|
frequency_penalty=state['frequency_penalty'],
|
||||||
stop_words_list=None,
|
no_repeat_ngram_size=state['no_repeat_ngram_size'] if state['no_repeat_ngram_size'] > 0 else None,
|
||||||
bad_words_list=None,
|
seed=state['seed'],
|
||||||
lora_uids=None,
|
ignore_eos=state['ban_eos_token'],
|
||||||
prompt_table_path=None,
|
add_special_tokens=state['add_bos_token'],
|
||||||
prompt_tasks=None,
|
skip_special_tokens=state['skip_special_tokens'],
|
||||||
streaming=not shared.args.cpp_runner,
|
|
||||||
output_sequence_lengths=True,
|
|
||||||
return_dict=True,
|
|
||||||
medusa_choices=None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
stop_event = state.get('stop_event')
|
||||||
|
result = self.llm.generate_async(prompt, sampling_params=sampling_params, streaming=True)
|
||||||
|
|
||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
starting_from = batch_input_ids[0].shape[-1]
|
for output in result:
|
||||||
|
if shared.stop_everything or (stop_event and stop_event.is_set()):
|
||||||
if shared.args.cpp_runner:
|
result.abort()
|
||||||
sequence_length = generator['sequence_lengths'][0].item()
|
|
||||||
output_ids = generator['output_ids'][0][0][:sequence_length].tolist()
|
|
||||||
|
|
||||||
cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
|
|
||||||
starting_from = sequence_length
|
|
||||||
yield cumulative_reply
|
|
||||||
else:
|
|
||||||
for curr_outputs in generator:
|
|
||||||
if shared.stop_everything:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
sequence_length = curr_outputs['sequence_lengths'][0].item()
|
text_diff = output.outputs[0].text_diff
|
||||||
output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()
|
if text_diff:
|
||||||
|
cumulative_reply += text_diff
|
||||||
cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
|
|
||||||
starting_from = sequence_length
|
|
||||||
yield cumulative_reply
|
yield cumulative_reply
|
||||||
|
|
||||||
def generate(self, prompt, state):
|
def generate(self, prompt, state):
|
||||||
|
|
@ -129,3 +64,8 @@ class TensorRTLLMModel:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def unload(self):
|
||||||
|
if hasattr(self, 'llm') and self.llm is not None:
|
||||||
|
self.llm.shutdown()
|
||||||
|
self.llm = None
|
||||||
|
|
|
||||||
|
|
@ -22,12 +22,22 @@ def generate_reply(*args, **kwargs):
|
||||||
from modules.models import load_model
|
from modules.models import load_model
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
|
state = args[1] if len(args) > 1 else kwargs.get('state', {})
|
||||||
|
use_parallel = (
|
||||||
|
state.get('stop_event') is not None
|
||||||
|
and shared.model.__class__.__name__ in ['Exllamav3Model', 'LlamaServer', 'TensorRTLLMModel']
|
||||||
|
and (shared.model.__class__.__name__ != 'LlamaServer' or shared.args.parallel > 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not use_parallel:
|
||||||
shared.generation_lock.acquire()
|
shared.generation_lock.acquire()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for result in _generate_reply(*args, **kwargs):
|
for result in _generate_reply(*args, **kwargs):
|
||||||
yield result
|
yield result
|
||||||
finally:
|
finally:
|
||||||
models.last_generation_time = time.time()
|
models.last_generation_time = time.time()
|
||||||
|
if not use_parallel:
|
||||||
shared.generation_lock.release()
|
shared.generation_lock.release()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -40,7 +50,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||||
yield ''
|
yield ''
|
||||||
return
|
return
|
||||||
|
|
||||||
if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
|
if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav3Model', 'TensorRTLLMModel']:
|
||||||
generate_func = generate_reply_custom
|
generate_func = generate_reply_custom
|
||||||
else:
|
else:
|
||||||
generate_func = generate_reply_HF
|
generate_func = generate_reply_HF
|
||||||
|
|
@ -68,7 +78,13 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||||
reply = ''
|
reply = ''
|
||||||
is_stream = state['stream']
|
is_stream = state['stream']
|
||||||
if len(all_stop_strings) > 0 and not state['stream']:
|
if len(all_stop_strings) > 0 and not state['stream']:
|
||||||
|
original_logits_processor = state.get('logits_processor')
|
||||||
|
stop_event_ref = state.pop('stop_event', None)
|
||||||
state = copy.deepcopy(state)
|
state = copy.deepcopy(state)
|
||||||
|
if stop_event_ref is not None:
|
||||||
|
state['stop_event'] = stop_event_ref
|
||||||
|
if original_logits_processor is not None:
|
||||||
|
state['logits_processor'] = original_logits_processor
|
||||||
state['stream'] = True
|
state['stream'] = True
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
|
|
@ -99,7 +115,8 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||||
yield reply
|
yield reply
|
||||||
last_update = time.monotonic()
|
last_update = time.monotonic()
|
||||||
|
|
||||||
if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
|
stop_event = state.get('stop_event')
|
||||||
|
if stop_found or shared.stop_everything or (stop_event and stop_event.is_set()):
|
||||||
break
|
break
|
||||||
|
|
||||||
if not is_chat:
|
if not is_chat:
|
||||||
|
|
@ -128,9 +145,9 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||||
|
|
||||||
from modules.torch_utils import get_device
|
from modules.torch_utils import get_device
|
||||||
|
|
||||||
if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
|
if shared.model.__class__.__name__ in ['Exllamav3Model', 'TensorRTLLMModel']:
|
||||||
input_ids = shared.tokenizer.encode(str(prompt))
|
input_ids = shared.tokenizer.encode(str(prompt))
|
||||||
if shared.model.__class__.__name__ not in ['Exllamav2Model', 'Exllamav3Model']:
|
if shared.model.__class__.__name__ not in ['Exllamav3Model']:
|
||||||
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
||||||
else:
|
else:
|
||||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
|
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
|
||||||
|
|
@ -148,7 +165,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
||||||
if truncation_length is not None:
|
if truncation_length is not None:
|
||||||
input_ids = input_ids[:, -truncation_length:]
|
input_ids = input_ids[:, -truncation_length:]
|
||||||
|
|
||||||
if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel'] or shared.args.cpu:
|
if shared.model.__class__.__name__ in ['Exllamav3Model', 'TensorRTLLMModel'] or shared.args.cpu:
|
||||||
return input_ids
|
return input_ids
|
||||||
else:
|
else:
|
||||||
device = get_device()
|
device = get_device()
|
||||||
|
|
@ -317,6 +334,8 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
|
||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'top_n_sigma',
|
'top_n_sigma',
|
||||||
|
'adaptive_target',
|
||||||
|
'adaptive_decay',
|
||||||
'dry_multiplier',
|
'dry_multiplier',
|
||||||
'dry_allowed_length',
|
'dry_allowed_length',
|
||||||
'dry_base',
|
'dry_base',
|
||||||
|
|
@ -359,7 +378,7 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
|
||||||
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
|
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
|
||||||
|
|
||||||
if state['custom_token_bans']:
|
if state['custom_token_bans']:
|
||||||
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
|
to_ban = [int(x.strip()) for x in state['custom_token_bans'].split(',') if x.strip()]
|
||||||
if len(to_ban) > 0:
|
if len(to_ban) > 0:
|
||||||
if generate_params.get('suppress_tokens', None):
|
if generate_params.get('suppress_tokens', None):
|
||||||
generate_params['suppress_tokens'] += to_ban
|
generate_params['suppress_tokens'] += to_ban
|
||||||
|
|
@ -370,8 +389,6 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
|
||||||
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
|
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
|
||||||
|
|
||||||
generate_params.update({'use_cache': not shared.args.no_cache})
|
generate_params.update({'use_cache': not shared.args.no_cache})
|
||||||
if shared.args.deepspeed:
|
|
||||||
generate_params.update({'synced_gpus': True})
|
|
||||||
|
|
||||||
# Encode the input
|
# Encode the input
|
||||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||||
|
|
@ -474,7 +491,10 @@ def generate_reply_custom(question, original_question, state, stopping_strings=N
|
||||||
For models that do not use the transformers library for sampling
|
For models that do not use the transformers library for sampling
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
stop_event_ref = state.pop('stop_event', None)
|
||||||
state = copy.deepcopy(state)
|
state = copy.deepcopy(state)
|
||||||
|
if stop_event_ref is not None:
|
||||||
|
state['stop_event'] = stop_event_ref
|
||||||
state['seed'] = set_manual_seed(state['seed'])
|
state['seed'] = set_manual_seed(state['seed'])
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
reply = ''
|
reply = ''
|
||||||
|
|
|
||||||
667
modules/tool_parsing.py
Normal file
667
modules/tool_parsing.py
Normal file
|
|
@ -0,0 +1,667 @@
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_call_id() -> str:
|
||||||
|
letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
b = [random.choice(letter_bytes) for _ in range(8)]
|
||||||
|
return "call_" + "".join(b).lower()
|
||||||
|
|
||||||
|
|
||||||
|
# All known opening markers for tool calls across model formats.
|
||||||
|
TOOL_CALL_OPENING_MARKERS = [
|
||||||
|
'<tool_call>',
|
||||||
|
'<function_call>',
|
||||||
|
'<minimax:tool_call>',
|
||||||
|
'<|tool_call_begin|>',
|
||||||
|
'<|tool_calls_section_begin|>',
|
||||||
|
'<|tool▁call▁begin|>',
|
||||||
|
'<|tool▁calls▁begin|>',
|
||||||
|
'[TOOL_CALLS]',
|
||||||
|
'to=functions.',
|
||||||
|
'<|channel|>commentary',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False):
|
||||||
|
'''
|
||||||
|
Check whether streaming output should be withheld because it may
|
||||||
|
contain tool-call markup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Full accumulated internal text.
|
||||||
|
markers: Template-specific markers for partial-prefix matching.
|
||||||
|
If None, falls back to TOOL_CALL_OPENING_MARKERS.
|
||||||
|
tool_names: List of tool function names.
|
||||||
|
check_bare_names: Whether to do partial-prefix matching on tool
|
||||||
|
names (for models with unknown template format).
|
||||||
|
'''
|
||||||
|
# Full marker found in text → buffer permanently.
|
||||||
|
# Always checks ALL known markers regardless of template (cheap safety net).
|
||||||
|
for marker in TOOL_CALL_OPENING_MARKERS:
|
||||||
|
if marker in text:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Bare function-name full match: "get_weather{...}" or "get_weather {...}"
|
||||||
|
if tool_names:
|
||||||
|
for name in tool_names:
|
||||||
|
if name + '{' in text or name + ' {' in text:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Partial-prefix matching: only for template-specific markers.
|
||||||
|
for marker in (markers if markers is not None else TOOL_CALL_OPENING_MARKERS):
|
||||||
|
for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1):
|
||||||
|
if text.endswith(marker[:prefix_len]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Bare-name partial matching: only when template format is unknown.
|
||||||
|
if check_bare_names and tool_names:
|
||||||
|
for name in tool_names:
|
||||||
|
if text.endswith(name):
|
||||||
|
return True
|
||||||
|
for prefix_len in range(min(len(name) - 1, len(text)), 0, -1):
|
||||||
|
if text.endswith(name[:prefix_len]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]):
|
||||||
|
# check if property 'function' exists and is a dictionary, otherwise adapt dict
|
||||||
|
if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
|
||||||
|
candidate_dict = {"type": "function", "function": candidate_dict}
|
||||||
|
if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
|
||||||
|
candidate_dict['name'] = candidate_dict['function']
|
||||||
|
del candidate_dict['function']
|
||||||
|
candidate_dict = {"type": "function", "function": candidate_dict}
|
||||||
|
if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
|
||||||
|
# check if 'name' exists within 'function' and is part of known tools
|
||||||
|
if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
|
||||||
|
candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value
|
||||||
|
# map property 'parameters' used by some older models to 'arguments'
|
||||||
|
if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
|
||||||
|
candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
|
||||||
|
del candidate_dict["function"]["parameters"]
|
||||||
|
return candidate_dict
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_balanced_json(text: str, start: int) -> str | None:
|
||||||
|
"""Extract a balanced JSON object from text starting at the given position.
|
||||||
|
|
||||||
|
Walks through the string tracking brace depth and string boundaries
|
||||||
|
to correctly handle arbitrary nesting levels.
|
||||||
|
"""
|
||||||
|
if start >= len(text) or text[start] != '{':
|
||||||
|
return None
|
||||||
|
depth = 0
|
||||||
|
in_string = False
|
||||||
|
escape_next = False
|
||||||
|
for i in range(start, len(text)):
|
||||||
|
c = text[i]
|
||||||
|
if escape_next:
|
||||||
|
escape_next = False
|
||||||
|
continue
|
||||||
|
if c == '\\' and in_string:
|
||||||
|
escape_next = True
|
||||||
|
continue
|
||||||
|
if c == '"':
|
||||||
|
in_string = not in_string
|
||||||
|
continue
|
||||||
|
if in_string:
|
||||||
|
continue
|
||||||
|
if c == '{':
|
||||||
|
depth += 1
|
||||||
|
elif c == '}':
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
return text[start:i + 1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_channel_tool_calls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse channel-based tool calls used by GPT-OSS and similar models.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"}
|
||||||
|
or:
|
||||||
|
<|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
# Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
|
||||||
|
# Pattern 2: to=functions.NAME after <|channel|> (alternative format)
|
||||||
|
patterns = [
|
||||||
|
r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
|
||||||
|
r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
|
||||||
|
]
|
||||||
|
for pattern in patterns:
|
||||||
|
for m in re.finditer(pattern, answer):
|
||||||
|
func_name = m.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
json_str = _extract_balanced_json(answer, m.end())
|
||||||
|
if json_str is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
arguments = json.loads(json_str)
|
||||||
|
if start_pos is None:
|
||||||
|
prefix = answer.rfind('<|start|>assistant', 0, m.start())
|
||||||
|
start_pos = prefix if prefix != -1 else m.start()
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
if matches:
|
||||||
|
break
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
[TOOL_CALLS]func_name[ARGS]{"arg": "value"}
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
for m in re.finditer(
|
||||||
|
r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
|
||||||
|
answer
|
||||||
|
):
|
||||||
|
func_name = m.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
json_str = _extract_balanced_json(answer, m.end())
|
||||||
|
if json_str is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
arguments = json.loads(json_str)
|
||||||
|
if start_pos is None:
|
||||||
|
start_pos = m.start()
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse bare function-name style tool calls used by Mistral and similar models.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
functionName{"arg": "value"}
|
||||||
|
Multiple calls are concatenated directly or separated by whitespace.
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
# Match tool name followed by opening brace, then extract balanced JSON
|
||||||
|
escaped_names = [re.escape(name) for name in tool_names]
|
||||||
|
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
|
||||||
|
for match in re.finditer(pattern, answer):
|
||||||
|
text = match.group(0)
|
||||||
|
name = None
|
||||||
|
for n in tool_names:
|
||||||
|
if text.startswith(n):
|
||||||
|
name = n
|
||||||
|
break
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
brace_start = match.end() - 1
|
||||||
|
json_str = _extract_balanced_json(answer, brace_start)
|
||||||
|
if json_str is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
arguments = json.loads(json_str)
|
||||||
|
if start_pos is None:
|
||||||
|
start_pos = match.start()
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse XML-parameter style tool calls used by Qwen3.5 and similar models.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<tool_call>
|
||||||
|
<function=function_name>
|
||||||
|
<parameter=param_name>value</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
|
||||||
|
tc_content = tc_match.group(1)
|
||||||
|
func_match = re.search(r'<function=([^>]+)>', tc_content)
|
||||||
|
if not func_match:
|
||||||
|
continue
|
||||||
|
func_name = func_match.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
arguments = {}
|
||||||
|
for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
|
||||||
|
param_name = param_match.group(1).strip()
|
||||||
|
param_value = param_match.group(2).strip()
|
||||||
|
try:
|
||||||
|
param_value = json.loads(param_value)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass # keep as string
|
||||||
|
arguments[param_name] = param_value
|
||||||
|
if start_pos is None:
|
||||||
|
start_pos = tc_match.start()
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_kimi_tool_calls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse Kimi-K2-style tool calls using pipe-delimited tokens.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<|tool_calls_section_begin|>
|
||||||
|
<|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
|
||||||
|
<|tool_calls_section_end|>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
for m in re.finditer(
|
||||||
|
r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
|
||||||
|
answer
|
||||||
|
):
|
||||||
|
func_name = m.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
json_str = _extract_balanced_json(answer, m.end())
|
||||||
|
if json_str is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
arguments = json.loads(json_str)
|
||||||
|
if start_pos is None:
|
||||||
|
# Check for section begin marker before the call marker
|
||||||
|
section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start())
|
||||||
|
start_pos = section if section != -1 else m.start()
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_minimax_tool_calls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse MiniMax-style tool calls using invoke/parameter XML tags.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<minimax:tool_call>
|
||||||
|
<invoke name="function_name">
|
||||||
|
<parameter name="param_name">value</parameter>
|
||||||
|
</invoke>
|
||||||
|
</minimax:tool_call>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
|
||||||
|
tc_content = tc_match.group(1)
|
||||||
|
# Split on <invoke> to handle multiple parallel calls in one block
|
||||||
|
for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
|
||||||
|
func_name = invoke_match.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
invoke_body = invoke_match.group(2)
|
||||||
|
arguments = {}
|
||||||
|
for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
|
||||||
|
param_name = param_match.group(1).strip()
|
||||||
|
param_value = param_match.group(2).strip()
|
||||||
|
try:
|
||||||
|
param_value = json.loads(param_value)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass # keep as string
|
||||||
|
arguments[param_name] = param_value
|
||||||
|
if start_pos is None:
|
||||||
|
start_pos = tc_match.start()
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
for m in re.finditer(
|
||||||
|
r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*',
|
||||||
|
answer
|
||||||
|
):
|
||||||
|
func_name = m.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
json_str = _extract_balanced_json(answer, m.end())
|
||||||
|
if json_str is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
arguments = json.loads(json_str)
|
||||||
|
if start_pos is None:
|
||||||
|
# Check for section begin marker before the call marker
|
||||||
|
section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start())
|
||||||
|
start_pos = section if section != -1 else m.start()
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse GLM-style tool calls using arg_key/arg_value XML pairs.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
<tool_call>function_name
|
||||||
|
<arg_key>key1</arg_key>
|
||||||
|
<arg_value>value1</arg_value>
|
||||||
|
</tool_call>
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
|
||||||
|
tc_content = tc_match.group(1)
|
||||||
|
# First non-tag text is the function name
|
||||||
|
name_match = re.match(r'([^<\s]+)', tc_content.strip())
|
||||||
|
if not name_match:
|
||||||
|
continue
|
||||||
|
func_name = name_match.group(1).strip()
|
||||||
|
if func_name not in tool_names:
|
||||||
|
continue
|
||||||
|
# Extract arg_key/arg_value pairs
|
||||||
|
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
|
||||||
|
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
|
||||||
|
if len(keys) != len(vals):
|
||||||
|
continue
|
||||||
|
arguments = {}
|
||||||
|
for k, v in zip(keys, vals):
|
||||||
|
try:
|
||||||
|
v = json.loads(v)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass # keep as string
|
||||||
|
arguments[k] = v
|
||||||
|
if start_pos is None:
|
||||||
|
start_pos = tc_match.start()
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
|
||||||
|
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
|
||||||
|
|
||||||
|
Format:
|
||||||
|
[func_name(param1="value1", param2="value2"), func_name2(...)]
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
# Match a bracketed list of function calls
|
||||||
|
bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
|
||||||
|
if not bracket_match:
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
inner = bracket_match.group(1)
|
||||||
|
|
||||||
|
# Build pattern for known tool names
|
||||||
|
escaped_names = [re.escape(name) for name in tool_names]
|
||||||
|
name_pattern = '|'.join(escaped_names)
|
||||||
|
|
||||||
|
for call_match in re.finditer(
|
||||||
|
r'(' + name_pattern + r')\(([^)]*)\)',
|
||||||
|
inner
|
||||||
|
):
|
||||||
|
func_name = call_match.group(1)
|
||||||
|
params_str = call_match.group(2).strip()
|
||||||
|
arguments = {}
|
||||||
|
|
||||||
|
if params_str:
|
||||||
|
# Parse key="value" pairs, handling commas inside quoted values
|
||||||
|
for param_match in re.finditer(
|
||||||
|
r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
|
||||||
|
params_str
|
||||||
|
):
|
||||||
|
param_name = param_match.group(1)
|
||||||
|
param_value = param_match.group(2).strip()
|
||||||
|
# Strip surrounding quotes
|
||||||
|
if (param_value.startswith('"') and param_value.endswith('"')) or \
|
||||||
|
(param_value.startswith("'") and param_value.endswith("'")):
|
||||||
|
param_value = param_value[1:-1]
|
||||||
|
# Try to parse as JSON for numeric/bool/null values
|
||||||
|
try:
|
||||||
|
param_value = json.loads(param_value)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass
|
||||||
|
arguments[param_name] = param_value
|
||||||
|
|
||||||
|
if start_pos is None:
|
||||||
|
start_pos = bracket_match.start()
|
||||||
|
matches.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return matches, start_pos
|
||||||
|
|
||||||
|
|
||||||
|
# Format registry: maps template substrings to the parser and streaming
|
||||||
|
# markers for that format. When a format's hints are NOT found in the
|
||||||
|
# template, its parser and markers are excluded.
|
||||||
|
TOOL_CALL_FORMATS = [
|
||||||
|
{
|
||||||
|
'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'],
|
||||||
|
'parser': _parse_deep_seek_tool_calls,
|
||||||
|
'markers': ['<|tool▁call▁begin|>', '<|tool▁calls▁begin|>'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'],
|
||||||
|
'parser': _parse_kimi_tool_calls,
|
||||||
|
'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'template_hints': ['to=functions.', '<|channel|>'],
|
||||||
|
'parser': _parse_channel_tool_calls,
|
||||||
|
'markers': ['to=functions.', '<|channel|>commentary'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'template_hints': ['minimax:tool_call'],
|
||||||
|
'parser': _parse_minimax_tool_calls,
|
||||||
|
'markers': ['<minimax:tool_call>'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'template_hints': ['<arg_key>'],
|
||||||
|
'parser': _parse_glm_tool_calls,
|
||||||
|
'markers': ['<tool_call>'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'template_hints': ['<tool_call>'],
|
||||||
|
'parser': _parse_xml_param_tool_calls,
|
||||||
|
'markers': ['<tool_call>'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'template_hints': ['[TOOL_CALLS]'],
|
||||||
|
'parser': _parse_mistral_token_tool_calls,
|
||||||
|
'markers': ['[TOOL_CALLS]'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'template_hints': ['<function_call>'],
|
||||||
|
'parser': None,
|
||||||
|
'markers': ['<function_call>'],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Default ordered list of all specialized parsers.
|
||||||
|
ALL_PARSERS = [
|
||||||
|
_parse_deep_seek_tool_calls,
|
||||||
|
_parse_kimi_tool_calls,
|
||||||
|
_parse_channel_tool_calls,
|
||||||
|
_parse_minimax_tool_calls,
|
||||||
|
_parse_glm_tool_calls,
|
||||||
|
_parse_xml_param_tool_calls,
|
||||||
|
_parse_mistral_token_tool_calls,
|
||||||
|
_parse_bare_name_tool_calls,
|
||||||
|
_parse_pythonic_tool_calls,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def detect_tool_call_format(template_str):
|
||||||
|
"""Inspect a chat/instruction template to determine which tool call
|
||||||
|
formats are relevant.
|
||||||
|
|
||||||
|
Uses an exclude-based approach: starts with all parsers/markers,
|
||||||
|
then removes the ones whose hints are not found in the template.
|
||||||
|
|
||||||
|
Returns (parsers, streaming_markers, check_bare_names).
|
||||||
|
"""
|
||||||
|
if not template_str:
|
||||||
|
return None, TOOL_CALL_OPENING_MARKERS, True
|
||||||
|
|
||||||
|
matched_any = False
|
||||||
|
exclude_parsers = []
|
||||||
|
exclude_markers = []
|
||||||
|
matched_markers = []
|
||||||
|
|
||||||
|
for fmt in TOOL_CALL_FORMATS:
|
||||||
|
if any(hint in template_str for hint in fmt['template_hints']):
|
||||||
|
matched_any = True
|
||||||
|
matched_markers.extend(fmt['markers'])
|
||||||
|
else:
|
||||||
|
if fmt['parser'] is not None:
|
||||||
|
exclude_parsers.append(fmt['parser'])
|
||||||
|
exclude_markers.extend(fmt['markers'])
|
||||||
|
|
||||||
|
if not matched_any:
|
||||||
|
return None, TOOL_CALL_OPENING_MARKERS, True
|
||||||
|
|
||||||
|
parsers = [p for p in ALL_PARSERS if p not in exclude_parsers]
|
||||||
|
markers = [m for m in TOOL_CALL_OPENING_MARKERS if m not in exclude_markers or m in matched_markers]
|
||||||
|
|
||||||
|
return parsers, markers, False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None):
|
||||||
|
matches = []
|
||||||
|
start_pos = None
|
||||||
|
|
||||||
|
def _return(matches, start_pos):
|
||||||
|
if return_prefix:
|
||||||
|
prefix = answer[:start_pos] if matches and start_pos is not None else ''
|
||||||
|
return matches, prefix
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Try specialized parsers.
|
||||||
|
for parser in (parsers if parsers is not None else ALL_PARSERS):
|
||||||
|
matches, start_pos = parser(answer, tool_names)
|
||||||
|
if matches:
|
||||||
|
return _return(matches, start_pos)
|
||||||
|
|
||||||
|
# Generic fallback: regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
|
||||||
|
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
|
||||||
|
|
||||||
|
for pattern in patterns:
|
||||||
|
for match in re.finditer(pattern, answer, re.DOTALL):
|
||||||
|
if match.group(2) is None:
|
||||||
|
continue
|
||||||
|
# remove backtick wraps if present
|
||||||
|
candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
|
||||||
|
candidate = re.sub(r"```$", "", candidate.strip())
|
||||||
|
# unwrap inner tags
|
||||||
|
candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
|
||||||
|
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
|
||||||
|
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
|
||||||
|
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
|
||||||
|
if not candidate.strip().startswith("["):
|
||||||
|
candidate = "[" + candidate + "]"
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
try:
|
||||||
|
# parse the candidate JSON into a dictionary
|
||||||
|
candidates = json.loads(candidate)
|
||||||
|
if not isinstance(candidates, list):
|
||||||
|
candidates = [candidates]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Ignore invalid JSON silently
|
||||||
|
continue
|
||||||
|
|
||||||
|
for candidate_dict in candidates:
|
||||||
|
checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
|
||||||
|
if checked_candidate is not None:
|
||||||
|
if start_pos is None:
|
||||||
|
start_pos = match.start()
|
||||||
|
matches.append(checked_candidate)
|
||||||
|
|
||||||
|
# last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
|
||||||
|
if len(matches) == 0:
|
||||||
|
try:
|
||||||
|
candidate = answer
|
||||||
|
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
|
||||||
|
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
|
||||||
|
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
|
||||||
|
if not candidate.strip().startswith("["):
|
||||||
|
candidate = "[" + candidate + "]"
|
||||||
|
# parse the candidate JSON into a dictionary
|
||||||
|
candidates = json.loads(candidate)
|
||||||
|
if not isinstance(candidates, list):
|
||||||
|
candidates = [candidates]
|
||||||
|
for candidate_dict in candidates:
|
||||||
|
checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
|
||||||
|
if checked_candidate is not None:
|
||||||
|
matches.append(checked_candidate)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Ignore invalid JSON silently
|
||||||
|
pass
|
||||||
|
|
||||||
|
return _return(matches, start_pos)
|
||||||
71
modules/tool_use.py
Normal file
71
modules/tool_use.py
Normal file
|
|
@ -0,0 +1,71 @@
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
from modules.utils import natural_keys, sanitize_filename
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_tools():
|
||||||
|
"""Return sorted list of tool script names from user_data/tools/*.py."""
|
||||||
|
tools_dir = shared.user_data_dir / 'tools'
|
||||||
|
tools_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
return sorted((p.stem for p in tools_dir.glob('*.py')), key=natural_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tools(selected_names):
|
||||||
|
"""
|
||||||
|
Import selected tool scripts and return their definitions and executors.
|
||||||
|
Returns (tool_defs, executors) where:
|
||||||
|
- tool_defs: list of OpenAI-format tool dicts
|
||||||
|
- executors: dict mapping function_name -> execute callable
|
||||||
|
"""
|
||||||
|
tool_defs = []
|
||||||
|
executors = {}
|
||||||
|
for name in selected_names:
|
||||||
|
name = sanitize_filename(name)
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
path = shared.user_data_dir / 'tools' / f'{name}.py'
|
||||||
|
if not path.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
spec = importlib.util.spec_from_file_location(f"tool_{name}", str(path))
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f'Failed to load tool script "{name}"')
|
||||||
|
continue
|
||||||
|
|
||||||
|
tool_def = getattr(module, 'tool', None)
|
||||||
|
execute_fn = getattr(module, 'execute', None)
|
||||||
|
if tool_def is None or execute_fn is None:
|
||||||
|
logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.')
|
||||||
|
continue
|
||||||
|
|
||||||
|
func_name = tool_def.get('function', {}).get('name', name)
|
||||||
|
if func_name in executors:
|
||||||
|
logger.warning(f'Tool "{name}" declares function name "{func_name}" which conflicts with an already loaded tool. Skipping.')
|
||||||
|
continue
|
||||||
|
tool_defs.append(tool_def)
|
||||||
|
executors[func_name] = execute_fn
|
||||||
|
|
||||||
|
return tool_defs, executors
|
||||||
|
|
||||||
|
|
||||||
|
def execute_tool(func_name, arguments, executors):
|
||||||
|
"""Execute a tool by function name. Returns result as a JSON string."""
|
||||||
|
fn = executors.get(func_name)
|
||||||
|
if fn is None:
|
||||||
|
return json.dumps({"error": f"Unknown tool: {func_name}"})
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(arguments, str):
|
||||||
|
arguments = json.loads(arguments)
|
||||||
|
result = fn(arguments)
|
||||||
|
return json.dumps(result) if not isinstance(result, str) else result
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f'Tool "{func_name}" execution failed')
|
||||||
|
return json.dumps({"error": str(e)})
|
||||||
|
|
@ -12,9 +12,6 @@ def get_device():
|
||||||
return shared.model.device
|
return shared.model.device
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
return torch.device('cuda')
|
return torch.device('cuda')
|
||||||
elif shared.args.deepspeed:
|
|
||||||
import deepspeed
|
|
||||||
return deepspeed.get_accelerator().current_device_name()
|
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
return torch.device('mps')
|
return torch.device('mps')
|
||||||
elif is_torch_xpu_available():
|
elif is_torch_xpu_available():
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import shared, ui, utils
|
from modules import shared, ui, utils
|
||||||
|
|
@ -24,9 +25,8 @@ from modules.evaluate import (
|
||||||
)
|
)
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.models import reload_model
|
from modules.models import reload_model
|
||||||
from modules.utils import natural_keys
|
|
||||||
|
|
||||||
PARAMETERS = ["lora_name", "always_override", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
|
PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to"]
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
|
|
||||||
train_log = {}
|
train_log = {}
|
||||||
|
|
@ -53,7 +53,8 @@ def create_ui():
|
||||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
|
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
|
||||||
|
|
||||||
with gr.Accordion(label='Target Modules', open=False, elem_classes='tgw-accordion'):
|
with gr.Accordion(label='Target Modules', open=False, elem_classes='tgw-accordion'):
|
||||||
gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM requirements and adapter size.\nNOTE: Only works for model_id='llama', other types will retain default training behavior and not use these settings.")
|
gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM and adapter size.")
|
||||||
|
all_linear = gr.Checkbox(label='Target all linear layers', value=True, info='Targets every nn.Linear layer except lm_head. Works for any model architecture. When checked, the individual module checkboxes below are ignored.', elem_classes=['no-background'])
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
q_proj_en = gr.Checkbox(label='Enable q_proj', value=True)
|
q_proj_en = gr.Checkbox(label='Enable q_proj', value=True)
|
||||||
|
|
@ -72,67 +73,60 @@ def create_ui():
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
|
lora_rank = gr.Slider(label='LoRA Rank', value=8, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
|
||||||
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
lora_alpha = gr.Slider(label='LoRA Alpha', value=16, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
||||||
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
|
batch_size = gr.Slider(label='Batch Size', value=32, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
|
||||||
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
|
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
|
||||||
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
|
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=512, step=32, info='Maximum sequence length in tokens. For instruction datasets, conversations longer than this are dropped. For text datasets, documents are split into chunks of this size. Higher values require more VRAM.')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
|
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a full training checkpoint (adapter weights, optimizer, scheduler) will be saved every time this many steps pass. Training can be resumed from these checkpoints.')
|
||||||
|
|
||||||
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
|
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
|
||||||
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
|
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])
|
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='cosine', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])
|
||||||
|
|
||||||
with gr.Accordion(label='Advanced Options', open=False, elem_classes='tgw-accordion'):
|
with gr.Accordion(label='Advanced Options', open=False, elem_classes='tgw-accordion'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.0, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
|
||||||
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
|
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.', elem_classes=['slim-dropdown'])
|
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Optimizer algorithm. adamw_torch is the standard choice. adamw_bnb_8bit uses less VRAM. adafactor is memory-efficient for large models.', elem_classes=['slim-dropdown'])
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
|
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate is gradually ramped up from 0 to the target value. This prevents unstable updates early in training.')
|
||||||
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
|
|
||||||
|
|
||||||
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item. In case of raw text, the EOS will be added at the Hard Cut")
|
add_eos_token = gr.Checkbox(label='Add EOS token', value=True, info="Adds EOS token for each document in text datasets.")
|
||||||
|
excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown'])
|
||||||
|
|
||||||
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
||||||
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
|
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Tab(label='Formatted Dataset'):
|
with gr.Tab(label='Chat Dataset'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
format = gr.Dropdown(choices=utils.get_datasets('user_data/training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'], interactive=not mu)
|
dataset = gr.Dropdown(choices=utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with chat conversations (messages or ShareGPT format). Each row is one conversation.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/formats', 'json')}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'], interactive=not mu)
|
format = gr.Dropdown(choices=get_instruction_templates(), value='None', label='Instruction Template', info='Select an instruction template for formatting the dataset, or "Chat Template" to use the model\'s built-in chat template.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_instruction_templates()}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
|
with gr.Tab(label="Text Dataset"):
|
||||||
|
with gr.Row():
|
||||||
|
text_dataset = gr.Dropdown(choices=utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||||
|
ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
|
stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=256, step=32, info='Overlap between chunks in tokens. 0 = no overlap. Values like 256 or 512 help preserve context across chunk boundaries.')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
eval_dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu)
|
eval_dataset = gr.Dropdown(choices=utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json')}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
||||||
|
|
||||||
with gr.Tab(label="Raw text file"):
|
|
||||||
with gr.Row():
|
|
||||||
raw_text_file = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.', elem_classes=['slim-dropdown'], interactive=not mu)
|
|
||||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'txt')}, 'refresh-button', interactive=not mu)
|
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='How many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length). Setting overlap to exactly half the cutoff length may be ideal.')
|
|
||||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
|
|
||||||
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number')
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_button = gr.Button("Start LoRA Training", variant='primary', interactive=not mu)
|
start_button = gr.Button("Start LoRA Training", variant='primary', interactive=not mu)
|
||||||
stop_button = gr.Button("Interrupt", interactive=not mu)
|
stop_button = gr.Button("Interrupt", interactive=not mu)
|
||||||
|
|
@ -143,7 +137,7 @@ def create_ui():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True, interactive=not mu)
|
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True, interactive=not mu)
|
||||||
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('user_data/training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under user_data/training/datasets.', interactive=not mu)
|
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'txt')[1:], value='wikitext', label='Input dataset', info=f'The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under {shared.user_data_dir}/training/datasets.', interactive=not mu)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
|
stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
|
||||||
|
|
@ -165,7 +159,7 @@ def create_ui():
|
||||||
refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
|
refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
|
||||||
|
|
||||||
# Training events
|
# Training events
|
||||||
all_params = [lora_name, always_override, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
|
all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to]
|
||||||
|
|
||||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||||
start_button.click(do_train, all_params, output)
|
start_button.click(do_train, all_params, output)
|
||||||
|
|
@ -229,9 +223,34 @@ def clean_path(base_path: str, path: str):
|
||||||
return f'{Path(base_path).absolute()}/{path}'
|
return f'{Path(base_path).absolute()}/{path}'
|
||||||
|
|
||||||
|
|
||||||
|
def get_instruction_templates():
|
||||||
|
path = shared.user_data_dir / 'instruction-templates'
|
||||||
|
names = set()
|
||||||
|
for ext in ['yaml', 'yml', 'jinja', 'jinja2']:
|
||||||
|
for f in path.glob(f'*.{ext}'):
|
||||||
|
names.add(f.stem)
|
||||||
|
return ['None', 'Chat Template'] + sorted(names, key=utils.natural_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def load_template(name):
|
||||||
|
"""Load a Jinja2 template string from {user_data_dir}/instruction-templates/."""
|
||||||
|
path = shared.user_data_dir / 'instruction-templates'
|
||||||
|
for ext in ['jinja', 'jinja2', 'yaml', 'yml']:
|
||||||
|
filepath = path / f'{name}.{ext}'
|
||||||
|
if filepath.exists():
|
||||||
|
if ext in ['jinja', 'jinja2']:
|
||||||
|
return filepath.read_text(encoding='utf-8')
|
||||||
|
else:
|
||||||
|
data = yaml.safe_load(filepath.read_text(encoding='utf-8'))
|
||||||
|
return data.get('instruction_template', '')
|
||||||
|
return ''
|
||||||
|
|
||||||
|
|
||||||
def backup_adapter(input_folder):
|
def backup_adapter(input_folder):
|
||||||
# Get the creation date of the file adapter_model.bin
|
# Get the creation date of the adapter file (safetensors or bin)
|
||||||
try:
|
try:
|
||||||
|
adapter_file = Path(f"{input_folder}/adapter_model.safetensors")
|
||||||
|
if not adapter_file.is_file():
|
||||||
adapter_file = Path(f"{input_folder}/adapter_model.bin")
|
adapter_file = Path(f"{input_folder}/adapter_model.bin")
|
||||||
if adapter_file.is_file():
|
if adapter_file.is_file():
|
||||||
|
|
||||||
|
|
@ -244,7 +263,7 @@ def backup_adapter(input_folder):
|
||||||
subfolder_path.mkdir(parents=True, exist_ok=True)
|
subfolder_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Check if the file already exists in the subfolder
|
# Check if the file already exists in the subfolder
|
||||||
backup_adapter_file = Path(f"{input_folder}/{creation_date_str}/adapter_model.bin")
|
backup_adapter_file = subfolder_path / adapter_file.name
|
||||||
if backup_adapter_file.is_file():
|
if backup_adapter_file.is_file():
|
||||||
print(" - Backup already exists. Skipping backup process.")
|
print(" - Backup already exists. Skipping backup process.")
|
||||||
return
|
return
|
||||||
|
|
@ -274,7 +293,7 @@ def calc_trainable_parameters(model):
|
||||||
return trainable_params, all_param
|
return trainable_params, all_param
|
||||||
|
|
||||||
|
|
||||||
def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
|
def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, excess_length: str, report_to: str):
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
|
@ -285,21 +304,17 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
set_peft_model_state_dict
|
set_peft_model_state_dict
|
||||||
)
|
)
|
||||||
from peft.utils.other import \
|
|
||||||
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
|
|
||||||
model_to_lora_modules
|
|
||||||
from transformers import is_torch_xpu_available
|
|
||||||
from transformers.models.auto.modeling_auto import (
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
|
||||||
)
|
|
||||||
|
|
||||||
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
|
|
||||||
|
|
||||||
global WANT_INTERRUPT
|
global WANT_INTERRUPT
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
|
|
||||||
# == Input validation / processing ==
|
# == Input validation / processing ==
|
||||||
yield "Preparing the input..."
|
yield "Preparing the input..."
|
||||||
|
|
||||||
|
if shared.args.loader == 'llama.cpp':
|
||||||
|
yield "Error: LoRA training requires a model loaded with the Transformers loader. GGUF models are not supported for training."
|
||||||
|
return
|
||||||
|
|
||||||
lora_file_path = clean_path(None, lora_name)
|
lora_file_path = clean_path(None, lora_name)
|
||||||
if lora_file_path.strip() == '':
|
if lora_file_path.strip() == '':
|
||||||
yield "Missing or invalid LoRA file name input."
|
yield "Missing or invalid LoRA file name input."
|
||||||
|
|
@ -309,10 +324,6 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
actual_lr = float(learning_rate)
|
actual_lr = float(learning_rate)
|
||||||
model_type = type(shared.model).__name__
|
model_type = type(shared.model).__name__
|
||||||
|
|
||||||
if model_type in MODEL_CLASSES:
|
|
||||||
model_id = MODEL_CLASSES[model_type]
|
|
||||||
else:
|
|
||||||
model_id = "llama"
|
|
||||||
if model_type == "PeftModelForCausalLM":
|
if model_type == "PeftModelForCausalLM":
|
||||||
if len(shared.lora_names) > 0:
|
if len(shared.lora_names) > 0:
|
||||||
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
|
|
@ -320,9 +331,6 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
else:
|
else:
|
||||||
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
|
logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
|
||||||
else:
|
|
||||||
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
|
||||||
logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
|
|
||||||
|
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
|
|
@ -330,166 +338,206 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
yield "Cannot input zeroes."
|
yield "Cannot input zeroes."
|
||||||
return
|
return
|
||||||
|
|
||||||
gradient_accumulation_steps = batch_size // micro_batch_size
|
gradient_accumulation_steps = max(1, batch_size // micro_batch_size)
|
||||||
shared.tokenizer.pad_token_id = 0
|
original_chat_template = getattr(shared.tokenizer, 'chat_template', None)
|
||||||
shared.tokenizer.padding_side = "left"
|
if shared.tokenizer.pad_token_id is None:
|
||||||
|
shared.tokenizer.pad_token_id = shared.tokenizer.eos_token_id
|
||||||
|
shared.tokenizer.padding_side = "right"
|
||||||
|
|
||||||
# Populate target_modules list with chosen X_proj modules. Llama-based models only atm, non-llama will revert to default behavior.
|
def list_target_modules():
|
||||||
def list_target_modules(model_id):
|
if all_linear:
|
||||||
if model_id != "llama" and model_id != "mistral":
|
return "all-linear"
|
||||||
return model_to_lora_modules[model_id]
|
|
||||||
|
|
||||||
available_modules = {
|
target_mods = [f"{name}_proj" for name, enabled in {
|
||||||
"gate": gate_proj_en,
|
"q": q_proj_en, "k": k_proj_en, "v": v_proj_en, "o": o_proj_en,
|
||||||
"down": down_proj_en,
|
"gate": gate_proj_en, "down": down_proj_en, "up": up_proj_en,
|
||||||
"up": up_proj_en,
|
}.items() if enabled]
|
||||||
"q": q_proj_en,
|
|
||||||
"v": v_proj_en,
|
|
||||||
"k": k_proj_en,
|
|
||||||
"o": o_proj_en,
|
|
||||||
}
|
|
||||||
target_mods = [f"{name}_proj" for name, enabled in available_modules.items() if enabled]
|
|
||||||
return target_mods
|
return target_mods
|
||||||
|
|
||||||
def encode(text, add_bos_token):
|
def normalize_messages(data_point):
|
||||||
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
|
"""Convert a dataset row to OpenAI messages format for apply_chat_template()."""
|
||||||
# Check if the first two tokens are BOS
|
if "messages" in data_point:
|
||||||
if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
|
return data_point["messages"]
|
||||||
result = result[1:]
|
|
||||||
|
|
||||||
if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
|
if "conversations" in data_point:
|
||||||
result = result[1:]
|
role_map = {"human": "user", "gpt": "assistant"}
|
||||||
return result
|
return [
|
||||||
|
{"role": role_map.get(turn.get("from", ""), turn.get("from", "")), "content": turn["value"]}
|
||||||
|
for turn in data_point["conversations"]
|
||||||
|
]
|
||||||
|
|
||||||
def tokenize(prompt, append_eos_token=False):
|
raise RuntimeError(
|
||||||
|
f'Dataset row must contain "messages" or "conversations" key. '
|
||||||
|
f'Found: {list(data_point.keys())}'
|
||||||
|
)
|
||||||
|
|
||||||
if train_only_after == '' or train_only_after not in prompt:
|
def tokenize_conversation(data_point):
|
||||||
input_ids = encode(prompt, True)
|
"""Tokenize using apply_chat_template() with assistant-only label masking."""
|
||||||
|
messages = normalize_messages(data_point)
|
||||||
|
full_ids = list(shared.tokenizer.apply_chat_template(messages, tokenize=True, return_dict=False))
|
||||||
|
|
||||||
if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
|
# Build labels: -100 for everything, then unmask assistant turns.
|
||||||
input_ids.append(shared.tokenizer.eos_token_id)
|
# This assumes apply_chat_template(messages[:i]) is a token-for-token
|
||||||
|
# prefix of apply_chat_template(messages[:i+1]), which holds for all
|
||||||
input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
|
# standard chat templates (Llama, ChatML, Mistral, etc.).
|
||||||
labels = [1] * len(input_ids)
|
labels = [-100] * len(full_ids)
|
||||||
|
for i, msg in enumerate(messages):
|
||||||
|
if msg["role"] == "assistant":
|
||||||
|
# Tokens up to where this assistant turn starts
|
||||||
|
header_ids = shared.tokenizer.apply_chat_template(
|
||||||
|
messages[:i], tokenize=True, return_dict=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
# Tokens through end of this assistant turn
|
||||||
|
through_ids = shared.tokenizer.apply_chat_template(
|
||||||
|
messages[:i + 1], tokenize=True, return_dict=False
|
||||||
|
)
|
||||||
|
# Unmask assistant tokens
|
||||||
|
start = len(header_ids)
|
||||||
|
end = min(len(through_ids), len(full_ids))
|
||||||
|
labels[start:end] = full_ids[start:end]
|
||||||
|
|
||||||
|
if len(full_ids) > cutoff_len:
|
||||||
|
if excess_length == 'truncate':
|
||||||
|
full_ids = full_ids[:cutoff_len]
|
||||||
|
labels = labels[:cutoff_len]
|
||||||
else:
|
else:
|
||||||
ind = prompt.index(train_only_after) + len(train_only_after)
|
return {"input_ids": [], "labels": [], "attention_mask": []}
|
||||||
before_tokens = encode(prompt[:ind], True)
|
|
||||||
after_tokens = encode(prompt[ind:], False)
|
|
||||||
|
|
||||||
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
|
|
||||||
after_tokens.append(shared.tokenizer.eos_token_id)
|
|
||||||
|
|
||||||
full_length = len(after_tokens) + len(before_tokens)
|
|
||||||
if full_length > cutoff_len:
|
|
||||||
after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
|
|
||||||
else:
|
|
||||||
before_tokens = [shared.tokenizer.pad_token_id] * (cutoff_len - full_length) + before_tokens
|
|
||||||
|
|
||||||
input_ids = before_tokens + after_tokens
|
|
||||||
labels = [-100] * len(before_tokens) + [1] * len(after_tokens)
|
|
||||||
|
|
||||||
input_ids = torch.tensor(input_ids)
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": full_ids,
|
||||||
"labels": labels,
|
"labels": labels,
|
||||||
"attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
|
"attention_mask": [1] * len(full_ids),
|
||||||
}
|
}
|
||||||
|
|
||||||
train_template.clear()
|
train_template.clear()
|
||||||
|
|
||||||
# == Prep the dataset, format, etc ==
|
# == Prep the dataset, format, etc ==
|
||||||
if raw_text_file not in ['None', '']:
|
has_text_dataset = text_dataset not in ['None', '']
|
||||||
train_template["template_type"] = "raw_text"
|
has_chat_dataset = dataset not in ['None', '']
|
||||||
logger.info("Loading raw text file dataset")
|
if has_text_dataset and has_chat_dataset:
|
||||||
fullpath = clean_path('user_data/training/datasets', f'{raw_text_file}')
|
yield "Error: select either a Chat Dataset or a Text Dataset, not both."
|
||||||
fullpath = Path(fullpath)
|
return
|
||||||
if fullpath.is_dir():
|
|
||||||
logger.info('Training path directory {}'.format(raw_text_file))
|
|
||||||
raw_text = ""
|
|
||||||
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
|
|
||||||
for file_path in file_paths:
|
|
||||||
if file_path.is_file():
|
|
||||||
with file_path.open('r', encoding='utf-8') as file:
|
|
||||||
raw_text += file.read().replace('\r', '')
|
|
||||||
|
|
||||||
logger.info(f"Loaded training file: {file_path.name}")
|
def tokenize_text_data(data):
|
||||||
else:
|
"""Tokenize text dataset rows, concatenate, and split into chunks."""
|
||||||
with open(clean_path('user_data/training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
all_tokens = []
|
||||||
raw_text = file.read().replace('\r', '')
|
for row in data:
|
||||||
|
tokens = shared.tokenizer.encode(row['text'])
|
||||||
cut_string = hard_cut_string.replace('\\n', '\n')
|
|
||||||
eos_added = 0
|
|
||||||
out_tokens = []
|
|
||||||
for text_part in raw_text.split(cut_string):
|
|
||||||
if len(text_part.strip()) <= min_chars:
|
|
||||||
continue
|
|
||||||
|
|
||||||
tokens = shared.tokenizer.encode(text_part)
|
|
||||||
if add_eos_token:
|
if add_eos_token:
|
||||||
tokens.append(shared.tokenizer.eos_token_id)
|
tokens.append(shared.tokenizer.eos_token_id)
|
||||||
eos_added += 1
|
all_tokens.extend(tokens)
|
||||||
|
|
||||||
|
stride = int(stride_length)
|
||||||
|
step = cutoff_len - stride if stride > 0 else cutoff_len
|
||||||
|
|
||||||
step = cutoff_len - overlap_len
|
|
||||||
if step <= 0:
|
if step <= 0:
|
||||||
yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
|
return None, "Error: stride length must be smaller than cutoff length."
|
||||||
return
|
if len(all_tokens) < cutoff_len:
|
||||||
|
return None, "Error: dataset is too short to fill even one chunk of the given cutoff length."
|
||||||
|
|
||||||
out_tokens.extend(split_chunks(tokens, cutoff_len, step))
|
chunks = []
|
||||||
|
for start in range(0, len(all_tokens), step):
|
||||||
if eos_added > 0:
|
chunk = all_tokens[start:start + cutoff_len]
|
||||||
print(f"EOS added to {eos_added} text blocks")
|
if len(chunk) == 0:
|
||||||
|
break
|
||||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
if len(chunk) < cutoff_len:
|
||||||
text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
|
pad_len = cutoff_len - len(chunk)
|
||||||
del out_tokens
|
chunks.append({
|
||||||
if newline_favor_len > 0:
|
"input_ids": chunk + [shared.tokenizer.pad_token_id] * pad_len,
|
||||||
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
|
"labels": list(chunk) + [-100] * pad_len,
|
||||||
|
"attention_mask": [1] * len(chunk) + [0] * pad_len,
|
||||||
train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
})
|
||||||
del text_chunks
|
|
||||||
eval_data = None
|
|
||||||
else:
|
else:
|
||||||
if dataset in ['None', '']:
|
chunks.append({
|
||||||
yield "Missing dataset choice input, cannot continue."
|
"input_ids": chunk,
|
||||||
|
"labels": list(chunk),
|
||||||
|
"attention_mask": [1] * cutoff_len,
|
||||||
|
})
|
||||||
|
|
||||||
|
return Dataset.from_list(chunks), None
|
||||||
|
|
||||||
|
if has_text_dataset:
|
||||||
|
train_template["template_type"] = "text_dataset"
|
||||||
|
logger.info("Loading text dataset")
|
||||||
|
data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{text_dataset}.json'))
|
||||||
|
|
||||||
|
if "text" not in data['train'].column_names:
|
||||||
|
yield "Error: text dataset must have a \"text\" key per row."
|
||||||
return
|
return
|
||||||
|
|
||||||
if format in ['None', '']:
|
train_data, err = tokenize_text_data(data['train'])
|
||||||
yield "Missing format choice input, cannot continue."
|
if err:
|
||||||
|
yield err
|
||||||
return
|
return
|
||||||
|
|
||||||
train_template["template_type"] = "dataset"
|
|
||||||
|
|
||||||
with open(clean_path('user_data/training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
|
|
||||||
format_data: dict[str, str] = json.load(formatFile)
|
|
||||||
|
|
||||||
# == store training prompt ==
|
|
||||||
for _, value in format_data.items():
|
|
||||||
prompt_key = f"template_{len(train_template)}"
|
|
||||||
train_template[prompt_key] = value
|
|
||||||
|
|
||||||
def generate_prompt(data_point: dict[str, str]):
|
|
||||||
for options, data in format_data.items():
|
|
||||||
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
|
|
||||||
for key, val in data_point.items():
|
|
||||||
if type(val) is str:
|
|
||||||
data = data.replace(f'%{key}%', val)
|
|
||||||
return data
|
|
||||||
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
|
|
||||||
|
|
||||||
def generate_and_tokenize_prompt(data_point):
|
|
||||||
prompt = generate_prompt(data_point)
|
|
||||||
return tokenize(prompt, add_eos_token)
|
|
||||||
|
|
||||||
logger.info("Loading JSON datasets")
|
|
||||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
|
|
||||||
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
|
||||||
|
|
||||||
if eval_dataset == 'None':
|
if eval_dataset == 'None':
|
||||||
eval_data = None
|
eval_data = None
|
||||||
else:
|
else:
|
||||||
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
|
eval_raw = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json'))
|
||||||
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
if "text" not in eval_raw['train'].column_names:
|
||||||
|
yield "Error: evaluation dataset must have a \"text\" key per row."
|
||||||
|
return
|
||||||
|
eval_data, err = tokenize_text_data(eval_raw['train'])
|
||||||
|
if err:
|
||||||
|
yield err
|
||||||
|
return
|
||||||
|
elif has_chat_dataset:
|
||||||
|
if format in ['None', '']:
|
||||||
|
yield "Missing format choice input, cannot continue."
|
||||||
|
return
|
||||||
|
|
||||||
|
if format == 'Chat Template':
|
||||||
|
if not getattr(shared.tokenizer, 'chat_template', None):
|
||||||
|
yield "Error: this model's tokenizer does not have a chat template. Select an instruction template instead, or load an instruct/chat model."
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Load custom instruction template and set on tokenizer
|
||||||
|
template_str = load_template(format)
|
||||||
|
if not template_str:
|
||||||
|
yield f"Error: could not load instruction template '{format}'."
|
||||||
|
return
|
||||||
|
shared.tokenizer.chat_template = template_str
|
||||||
|
|
||||||
|
# Unified path — both cases use tokenize_conversation()
|
||||||
|
train_template["template_type"] = "chat_template"
|
||||||
|
|
||||||
|
logger.info("Loading JSON dataset with chat template format")
|
||||||
|
data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{dataset}.json'))
|
||||||
|
|
||||||
|
# Validate the first row
|
||||||
|
try:
|
||||||
|
normalize_messages(data['train'][0])
|
||||||
|
except (RuntimeError, KeyError, IndexError) as e:
|
||||||
|
yield f"Error: {e}"
|
||||||
|
return
|
||||||
|
|
||||||
|
total = len(data['train'])
|
||||||
|
train_data = data['train'].map(
|
||||||
|
tokenize_conversation,
|
||||||
|
remove_columns=data['train'].column_names,
|
||||||
|
new_fingerprint='%030x' % random.randrange(16**30)
|
||||||
|
)
|
||||||
|
train_data = train_data.filter(lambda x: len(x['input_ids']) > 0)
|
||||||
|
dropped = total - len(train_data)
|
||||||
|
if dropped > 0:
|
||||||
|
logger.warning(f"Dropped {dropped}/{total} conversations exceeding cutoff length of {cutoff_len} tokens.")
|
||||||
|
if len(train_data) == 0:
|
||||||
|
yield f"Error: all {total} conversations exceed the cutoff length of {cutoff_len} tokens. Increase the cutoff length or shorten your data."
|
||||||
|
return
|
||||||
|
|
||||||
|
if eval_dataset == 'None':
|
||||||
|
eval_data = None
|
||||||
|
else:
|
||||||
|
eval_data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json'))
|
||||||
|
eval_data = eval_data['train'].map(
|
||||||
|
tokenize_conversation,
|
||||||
|
remove_columns=eval_data['train'].column_names,
|
||||||
|
new_fingerprint='%030x' % random.randrange(16**30)
|
||||||
|
)
|
||||||
|
eval_data = eval_data.filter(lambda x: len(x['input_ids']) > 0)
|
||||||
|
else:
|
||||||
|
yield "No dataset selected. Choose a Chat Dataset or a Text Dataset."
|
||||||
|
return
|
||||||
|
|
||||||
# == We MUST reload model if it went through any previous training, even failed one ==
|
# == We MUST reload model if it went through any previous training, even failed one ==
|
||||||
if shared.model_dirty_from_training:
|
if shared.model_dirty_from_training:
|
||||||
|
|
@ -502,12 +550,14 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
if shared.model is not None:
|
if shared.model is not None:
|
||||||
print("Model reloaded OK, continue with training.")
|
print("Model reloaded OK, continue with training.")
|
||||||
else:
|
else:
|
||||||
return f"Failed to load {selected_model}."
|
yield f"Failed to load {selected_model}."
|
||||||
except:
|
return
|
||||||
|
except Exception:
|
||||||
exc = traceback.format_exc()
|
exc = traceback.format_exc()
|
||||||
logger.error('Failed to reload the model.')
|
logger.error('Failed to reload the model.')
|
||||||
print(exc)
|
print(exc)
|
||||||
return exc.replace('\n', '\n\n')
|
yield exc.replace('\n', '\n\n')
|
||||||
|
return
|
||||||
|
|
||||||
# == Start prepping the model itself ==
|
# == Start prepping the model itself ==
|
||||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||||
|
|
@ -519,10 +569,15 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
shared.model_dirty_from_training = True
|
shared.model_dirty_from_training = True
|
||||||
|
|
||||||
logger.info("Preparing for training")
|
logger.info("Preparing for training")
|
||||||
|
target_modules = list_target_modules()
|
||||||
|
if not target_modules:
|
||||||
|
yield "No target modules selected. Enable at least one module or check 'Target all linear layers'."
|
||||||
|
return
|
||||||
|
|
||||||
config = LoraConfig(
|
config = LoraConfig(
|
||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
lora_alpha=lora_alpha,
|
lora_alpha=lora_alpha,
|
||||||
target_modules=list_target_modules(model_id),
|
target_modules=target_modules,
|
||||||
lora_dropout=lora_dropout,
|
lora_dropout=lora_dropout,
|
||||||
bias="none",
|
bias="none",
|
||||||
task_type="CAUSAL_LM"
|
task_type="CAUSAL_LM"
|
||||||
|
|
@ -535,14 +590,31 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
# == get model trainable params
|
# == get model trainable params
|
||||||
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
|
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
|
||||||
|
|
||||||
|
# == Determine if we can resume from a checkpoint ==
|
||||||
|
resume_checkpoint = None
|
||||||
try:
|
try:
|
||||||
logger.info("Creating LoRA model")
|
logger.info("Creating LoRA model")
|
||||||
lora_model = get_peft_model(shared.model, config)
|
lora_model = get_peft_model(shared.model, config)
|
||||||
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
|
if not always_override and Path(lora_file_path).exists():
|
||||||
logger.info("Loading existing LoRA data")
|
# Look for HF Trainer checkpoint dirs (full resumption)
|
||||||
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
|
checkpoints = sorted(Path(lora_file_path).glob("checkpoint-*"), key=os.path.getmtime)
|
||||||
|
if checkpoints:
|
||||||
|
resume_checkpoint = str(checkpoints[-1])
|
||||||
|
logger.info(f"Will resume from checkpoint: {resume_checkpoint}")
|
||||||
|
else:
|
||||||
|
# Legacy fallback: load bare adapter weights only
|
||||||
|
safetensors_path = Path(f"{lora_file_path}/adapter_model.safetensors")
|
||||||
|
bin_path = Path(f"{lora_file_path}/adapter_model.bin")
|
||||||
|
if safetensors_path.is_file():
|
||||||
|
logger.info("Loading existing LoRA data (safetensors)")
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
state_dict_peft = load_file(str(safetensors_path))
|
||||||
set_peft_model_state_dict(lora_model, state_dict_peft)
|
set_peft_model_state_dict(lora_model, state_dict_peft)
|
||||||
except:
|
elif bin_path.is_file():
|
||||||
|
logger.info("Loading existing LoRA data (bin)")
|
||||||
|
state_dict_peft = torch.load(str(bin_path), weights_only=True)
|
||||||
|
set_peft_model_state_dict(lora_model, state_dict_peft)
|
||||||
|
except Exception:
|
||||||
yield traceback.format_exc().replace('\n', '\n\n')
|
yield traceback.format_exc().replace('\n', '\n\n')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -562,14 +634,6 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
control.should_epoch_stop = True
|
control.should_epoch_stop = True
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
|
|
||||||
lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
|
|
||||||
# Save log
|
|
||||||
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
|
|
||||||
json.dump(train_log, file, indent=2)
|
|
||||||
# == Save training prompt ==
|
|
||||||
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_prompt.json", 'w', encoding='utf-8') as file:
|
|
||||||
json.dump(train_template, file, indent=2)
|
|
||||||
|
|
||||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
tracked.current_steps += 1
|
tracked.current_steps += 1
|
||||||
|
|
@ -586,22 +650,46 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='')
|
print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='')
|
||||||
if 'loss' in logs:
|
if 'loss' in logs:
|
||||||
loss = float(logs['loss'])
|
loss = float(logs['loss'])
|
||||||
if loss <= stop_at_loss:
|
if stop_at_loss > 0 and loss <= stop_at_loss:
|
||||||
control.should_epoch_stop = True
|
control.should_epoch_stop = True
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
|
print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
|
||||||
|
|
||||||
|
def on_save(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
|
checkpoint_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
|
||||||
|
if checkpoint_dir.exists():
|
||||||
|
with open(checkpoint_dir / "training_log.json", 'w', encoding='utf-8') as file:
|
||||||
|
json.dump(train_log, file, indent=2)
|
||||||
|
with open(checkpoint_dir / "training_prompt.json", 'w', encoding='utf-8') as file:
|
||||||
|
json.dump(train_template, file, indent=2)
|
||||||
|
|
||||||
# Fix training for mixed precision models
|
# Fix training for mixed precision models
|
||||||
for param in shared.model.parameters():
|
for param in shared.model.parameters():
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
param.data = param.data.float()
|
param.data = param.data.float()
|
||||||
|
|
||||||
|
lora_model.config.use_cache = False
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
max_len = max(len(item['input_ids']) for item in batch)
|
||||||
|
input_ids, labels, attention_mask = [], [], []
|
||||||
|
for item in batch:
|
||||||
|
pad_len = max_len - len(item['input_ids'])
|
||||||
|
input_ids.append(item['input_ids'] + [shared.tokenizer.pad_token_id] * pad_len)
|
||||||
|
labels.append(item['labels'] + [-100] * pad_len)
|
||||||
|
attention_mask.append(item['attention_mask'] + [0] * pad_len)
|
||||||
|
return {
|
||||||
|
'input_ids': torch.tensor(input_ids),
|
||||||
|
'labels': torch.tensor(labels),
|
||||||
|
'attention_mask': torch.tensor(attention_mask),
|
||||||
|
}
|
||||||
|
|
||||||
trainer = transformers.Trainer(
|
trainer = transformers.Trainer(
|
||||||
model=lora_model,
|
model=lora_model,
|
||||||
train_dataset=train_data,
|
train_dataset=train_data,
|
||||||
eval_dataset=eval_data,
|
eval_dataset=eval_data,
|
||||||
args=transformers.TrainingArguments(
|
args=transformers.TrainingArguments(
|
||||||
report_to=report_to if report_to != "None" else None,
|
report_to=report_to if report_to != "None" else "none",
|
||||||
per_device_train_batch_size=micro_batch_size,
|
per_device_train_batch_size=micro_batch_size,
|
||||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
|
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
|
||||||
|
|
@ -610,31 +698,27 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
fp16=False if shared.args.cpu or shared.args.bf16 else True,
|
fp16=False if shared.args.cpu or shared.args.bf16 else True,
|
||||||
bf16=shared.args.bf16,
|
bf16=shared.args.bf16,
|
||||||
optim=optimizer,
|
optim=optimizer,
|
||||||
logging_steps=2 if stop_at_loss > 0 else 5,
|
logging_steps=1,
|
||||||
eval_strategy="steps" if eval_data is not None else "no",
|
eval_strategy="steps" if eval_data is not None else "no",
|
||||||
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
|
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
|
||||||
save_strategy="steps" if eval_data is not None else "no",
|
save_strategy="steps" if save_steps > 0 or eval_data is not None else "no",
|
||||||
|
save_steps=actual_save_steps if save_steps > 0 else None,
|
||||||
output_dir=lora_file_path,
|
output_dir=lora_file_path,
|
||||||
lr_scheduler_type=lr_scheduler_type,
|
lr_scheduler_type=lr_scheduler_type,
|
||||||
load_best_model_at_end=eval_data is not None,
|
load_best_model_at_end=eval_data is not None,
|
||||||
# TODO: Enable multi-device support
|
# TODO: Enable multi-device support
|
||||||
ddp_find_unused_parameters=None,
|
ddp_find_unused_parameters=None,
|
||||||
no_cuda=shared.args.cpu,
|
use_cpu=shared.args.cpu,
|
||||||
# use_ipex=True if is_torch_xpu_available() and not shared.args.cpu else False
|
remove_unused_columns=False,
|
||||||
),
|
),
|
||||||
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
data_collator=collate_fn,
|
||||||
callbacks=list([Callbacks()])
|
callbacks=[Callbacks()]
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_model.config.use_cache = False
|
|
||||||
|
|
||||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
|
||||||
lora_model = torch.compile(lora_model)
|
|
||||||
|
|
||||||
# == Save parameters for reuse ==
|
# == Save parameters for reuse ==
|
||||||
with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
|
with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
|
||||||
vars = locals()
|
local_vars = locals()
|
||||||
json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
|
json.dump({x: local_vars[x] for x in PARAMETERS}, file, indent=2)
|
||||||
|
|
||||||
# == Save training prompt ==
|
# == Save training prompt ==
|
||||||
with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
|
with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
|
||||||
|
|
@ -646,9 +730,12 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
|
|
||||||
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
||||||
|
|
||||||
projections_string = ", ".join([projection.replace("_proj", "") for projection in list_target_modules(model_id)])
|
if target_modules == "all-linear":
|
||||||
|
projections_string = "all-linear"
|
||||||
|
else:
|
||||||
|
projections_string = ", ".join([projection.replace("_proj", "") for projection in target_modules])
|
||||||
|
|
||||||
print(f"Training '{model_id}' model using ({projections_string}) projections")
|
print(f"Training '{model_type}' model using ({projections_string}) projections")
|
||||||
|
|
||||||
if lora_all_param > 0:
|
if lora_all_param > 0:
|
||||||
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
|
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
|
||||||
|
|
@ -676,23 +763,31 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
decoded_entries.append({"value": decoded_text})
|
decoded_entries.append({"value": decoded_text})
|
||||||
|
|
||||||
# Write the log file
|
# Write the log file
|
||||||
Path('user_data/logs').mkdir(exist_ok=True)
|
(shared.user_data_dir / 'logs').mkdir(exist_ok=True)
|
||||||
with open(Path('user_data/logs/train_dataset_sample.json'), 'w') as json_file:
|
with open(shared.user_data_dir / 'logs' / 'train_dataset_sample.json', 'w') as json_file:
|
||||||
json.dump(decoded_entries, json_file, indent=4)
|
json.dump(decoded_entries, json_file, indent=4)
|
||||||
|
|
||||||
logger.info("Log file 'train_dataset_sample.json' created in the 'user_data/logs' directory.")
|
logger.info(f"Log file 'train_dataset_sample.json' created in the '{shared.user_data_dir}/logs' directory.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create log file due to error: {e}")
|
logger.error(f"Failed to create log file due to error: {e}")
|
||||||
|
|
||||||
|
thread_error = None
|
||||||
|
|
||||||
def threaded_run():
|
def threaded_run():
|
||||||
|
nonlocal thread_error
|
||||||
|
try:
|
||||||
log_train_dataset(trainer)
|
log_train_dataset(trainer)
|
||||||
trainer.train()
|
trainer.train(resume_from_checkpoint=resume_checkpoint)
|
||||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
|
tracked.did_save = True
|
||||||
logger.info("LoRA training run is completed and saved.")
|
logger.info("LoRA training run is completed and saved.")
|
||||||
# Save log
|
# Save log
|
||||||
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
|
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
|
||||||
json.dump(train_log, file, indent=2)
|
json.dump(train_log, file, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
thread_error = e
|
||||||
|
logger.error(f"Training error: {e}")
|
||||||
|
|
||||||
thread = threading.Thread(target=threaded_run)
|
thread = threading.Thread(target=threaded_run)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
@ -721,11 +816,20 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
|
|
||||||
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||||
|
|
||||||
|
# Check for errors from the training thread
|
||||||
|
if thread_error is not None:
|
||||||
|
yield f"Training failed: {thread_error}"
|
||||||
|
return
|
||||||
|
|
||||||
# Saving in the train thread might fail if an error occurs, so save here if so.
|
# Saving in the train thread might fail if an error occurs, so save here if so.
|
||||||
if not tracked.did_save:
|
if not tracked.did_save:
|
||||||
logger.info("Training complete, saving")
|
logger.info("Training complete, saving")
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
|
|
||||||
|
# Restore the original chat_template if we changed it for training
|
||||||
|
if shared.tokenizer is not None and hasattr(shared.tokenizer, 'chat_template'):
|
||||||
|
shared.tokenizer.chat_template = original_chat_template
|
||||||
|
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
logger.info("Training interrupted.")
|
logger.info("Training interrupted.")
|
||||||
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`."
|
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`."
|
||||||
|
|
@ -734,29 +838,6 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
|
||||||
yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training."
|
yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training."
|
||||||
|
|
||||||
|
|
||||||
def split_chunks(arr, size, step):
|
|
||||||
for i in range(0, len(arr), step):
|
|
||||||
yield arr[i:i + size]
|
|
||||||
|
|
||||||
|
|
||||||
def cut_chunk_for_newline(chunk: str, max_length: int):
|
|
||||||
if '\n' not in chunk:
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
first_newline = chunk.index('\n')
|
|
||||||
if first_newline < max_length:
|
|
||||||
chunk = chunk[first_newline + 1:]
|
|
||||||
|
|
||||||
if '\n' not in chunk:
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
last_newline = chunk.rindex('\n')
|
|
||||||
if len(chunk) - last_newline < max_length:
|
|
||||||
chunk = chunk[:last_newline]
|
|
||||||
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
|
|
||||||
def format_time(seconds: float):
|
def format_time(seconds: float):
|
||||||
if seconds < 120:
|
if seconds < 120:
|
||||||
return f"`{seconds:.0f}` seconds"
|
return f"`{seconds:.0f}` seconds"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import pprint
|
import pprint
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -6,11 +5,7 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import infer_auto_device_map, init_empty_weights
|
from accelerate import infer_auto_device_map, init_empty_weights
|
||||||
from accelerate.utils import (
|
from accelerate.utils import is_xpu_available
|
||||||
is_ccl_available,
|
|
||||||
is_npu_available,
|
|
||||||
is_xpu_available
|
|
||||||
)
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
|
|
@ -28,31 +23,6 @@ from modules.torch_utils import get_device
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
local_rank = None
|
|
||||||
if shared.args.deepspeed:
|
|
||||||
import deepspeed
|
|
||||||
from transformers.integrations.deepspeed import (
|
|
||||||
HfDeepSpeedConfig,
|
|
||||||
is_deepspeed_zero3_enabled
|
|
||||||
)
|
|
||||||
|
|
||||||
from modules.deepspeed_parameters import generate_ds_config
|
|
||||||
|
|
||||||
# Distributed setup
|
|
||||||
local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
|
|
||||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
|
||||||
if is_xpu_available() and is_ccl_available():
|
|
||||||
torch.xpu.set_device(local_rank)
|
|
||||||
deepspeed.init_distributed(backend="ccl")
|
|
||||||
elif is_npu_available():
|
|
||||||
torch.npu.set_device(local_rank)
|
|
||||||
deepspeed.init_distributed(dist_backend="hccl")
|
|
||||||
else:
|
|
||||||
torch.cuda.set_device(local_rank)
|
|
||||||
deepspeed.init_distributed()
|
|
||||||
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
|
|
||||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
|
||||||
|
|
||||||
|
|
||||||
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
|
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -95,14 +65,16 @@ class LogprobProcessor(LogitsProcessor):
|
||||||
def __init__(self, logprobs=None):
|
def __init__(self, logprobs=None):
|
||||||
self.logprobs = logprobs
|
self.logprobs = logprobs
|
||||||
self.token_alternatives = {}
|
self.token_alternatives = {}
|
||||||
|
self.token_alternatives_history = []
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
if self.logprobs is not None: # 0-5
|
if self.logprobs is not None: # 0-5
|
||||||
log_e_probabilities = F.log_softmax(logits, dim=1)
|
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs)
|
||||||
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
|
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
|
||||||
top_probs = [float(x) for x in top_values[0]]
|
top_probs = [float(x) for x in top_values[0]]
|
||||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||||
|
self.token_alternatives_history.append(self.token_alternatives)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
@ -163,10 +135,7 @@ def load_model_HF(model_name):
|
||||||
shared.args.load_in_8bit,
|
shared.args.load_in_8bit,
|
||||||
shared.args.load_in_4bit,
|
shared.args.load_in_4bit,
|
||||||
shared.args.disk,
|
shared.args.disk,
|
||||||
shared.args.deepspeed,
|
|
||||||
shared.args.cpu_memory is not None,
|
shared.args.cpu_memory is not None,
|
||||||
shared.args.compress_pos_emb > 1,
|
|
||||||
shared.args.alpha_value > 1,
|
|
||||||
])
|
])
|
||||||
|
|
||||||
# Load the model without any special settings
|
# Load the model without any special settings
|
||||||
|
|
@ -183,25 +152,6 @@ def load_model_HF(model_name):
|
||||||
if device:
|
if device:
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
# DeepSpeed ZeRO-3
|
|
||||||
elif shared.args.deepspeed:
|
|
||||||
model = LoaderClass.from_pretrained(
|
|
||||||
path_to_model,
|
|
||||||
torch_dtype=params['torch_dtype'],
|
|
||||||
trust_remote_code=params.get('trust_remote_code')
|
|
||||||
)
|
|
||||||
|
|
||||||
model = deepspeed.initialize(
|
|
||||||
model=model,
|
|
||||||
config_params=ds_config,
|
|
||||||
model_parameters=None,
|
|
||||||
optimizer=None,
|
|
||||||
lr_scheduler=None
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
model.module.eval() # Inference
|
|
||||||
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
|
|
||||||
|
|
||||||
# Load with quantization and/or offloading
|
# Load with quantization and/or offloading
|
||||||
else:
|
else:
|
||||||
if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
|
if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
|
||||||
|
|
@ -248,11 +198,6 @@ def load_model_HF(model_name):
|
||||||
if shared.args.disk:
|
if shared.args.disk:
|
||||||
params['offload_folder'] = str(Path(shared.args.disk_cache_dir))
|
params['offload_folder'] = str(Path(shared.args.disk_cache_dir))
|
||||||
|
|
||||||
if shared.args.compress_pos_emb > 1:
|
|
||||||
params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
|
|
||||||
elif shared.args.alpha_value > 1:
|
|
||||||
params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}
|
|
||||||
|
|
||||||
logger.info("TRANSFORMERS_PARAMS=")
|
logger.info("TRANSFORMERS_PARAMS=")
|
||||||
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
|
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
|
||||||
print()
|
print()
|
||||||
|
|
|
||||||
|
|
@ -113,65 +113,15 @@ if not shared.args.old_colors:
|
||||||
block_radius='0',
|
block_radius='0',
|
||||||
)
|
)
|
||||||
|
|
||||||
if Path("user_data/notification.mp3").exists():
|
if (shared.user_data_dir / "notification.mp3").exists():
|
||||||
audio_notification_js = "document.querySelector('#audio_notification audio')?.play();"
|
audio_notification_js = "document.querySelector('#audio_notification audio')?.play();"
|
||||||
else:
|
else:
|
||||||
audio_notification_js = ""
|
audio_notification_js = ""
|
||||||
|
|
||||||
|
|
||||||
def list_model_elements():
|
def list_model_elements():
|
||||||
elements = [
|
from modules.loaders import list_model_elements
|
||||||
'filter_by_loader',
|
return list_model_elements()
|
||||||
'loader',
|
|
||||||
'cpu_memory',
|
|
||||||
'gpu_layers',
|
|
||||||
'cpu_moe',
|
|
||||||
'threads',
|
|
||||||
'threads_batch',
|
|
||||||
'batch_size',
|
|
||||||
'ubatch_size',
|
|
||||||
'ctx_size',
|
|
||||||
'cache_type',
|
|
||||||
'tensor_split',
|
|
||||||
'extra_flags',
|
|
||||||
'streaming_llm',
|
|
||||||
'gpu_split',
|
|
||||||
'alpha_value',
|
|
||||||
'rope_freq_base',
|
|
||||||
'compress_pos_emb',
|
|
||||||
'compute_dtype',
|
|
||||||
'quant_type',
|
|
||||||
'num_experts_per_token',
|
|
||||||
'load_in_8bit',
|
|
||||||
'load_in_4bit',
|
|
||||||
'attn_implementation',
|
|
||||||
'cpu',
|
|
||||||
'disk',
|
|
||||||
'row_split',
|
|
||||||
'no_kv_offload',
|
|
||||||
'no_mmap',
|
|
||||||
'mlock',
|
|
||||||
'numa',
|
|
||||||
'use_double_quant',
|
|
||||||
'bf16',
|
|
||||||
'autosplit',
|
|
||||||
'enable_tp',
|
|
||||||
'tp_backend',
|
|
||||||
'no_flash_attn',
|
|
||||||
'no_xformers',
|
|
||||||
'no_sdpa',
|
|
||||||
'cfg_cache',
|
|
||||||
'cpp_runner',
|
|
||||||
'no_use_fast',
|
|
||||||
'model_draft',
|
|
||||||
'draft_max',
|
|
||||||
'gpu_layers_draft',
|
|
||||||
'device_draft',
|
|
||||||
'ctx_size_draft',
|
|
||||||
'mmproj',
|
|
||||||
]
|
|
||||||
|
|
||||||
return elements
|
|
||||||
|
|
||||||
|
|
||||||
def list_interface_input_elements():
|
def list_interface_input_elements():
|
||||||
|
|
@ -193,6 +143,8 @@ def list_interface_input_elements():
|
||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'top_n_sigma',
|
'top_n_sigma',
|
||||||
|
'adaptive_target',
|
||||||
|
'adaptive_decay',
|
||||||
'dry_multiplier',
|
'dry_multiplier',
|
||||||
'dry_allowed_length',
|
'dry_allowed_length',
|
||||||
'dry_base',
|
'dry_base',
|
||||||
|
|
@ -247,10 +199,12 @@ def list_interface_input_elements():
|
||||||
'unique_id',
|
'unique_id',
|
||||||
'textbox',
|
'textbox',
|
||||||
'start_with',
|
'start_with',
|
||||||
|
'selected_tools',
|
||||||
'mode',
|
'mode',
|
||||||
'chat_style',
|
'chat_style',
|
||||||
'chat-instruct_command',
|
'chat-instruct_command',
|
||||||
'character_menu',
|
'character_menu',
|
||||||
|
'user_menu',
|
||||||
'name2',
|
'name2',
|
||||||
'context',
|
'context',
|
||||||
'greeting',
|
'greeting',
|
||||||
|
|
@ -350,10 +304,16 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma
|
||||||
if k in shared.settings and k not in exclude:
|
if k in shared.settings and k not in exclude:
|
||||||
output[k] = state[k]
|
output[k] = state[k]
|
||||||
|
|
||||||
|
if preset:
|
||||||
output['preset'] = preset
|
output['preset'] = preset
|
||||||
output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook']
|
output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook']
|
||||||
|
if state.get('character_menu'):
|
||||||
output['character'] = state['character_menu']
|
output['character'] = state['character_menu']
|
||||||
|
if state.get('user_menu'):
|
||||||
|
output['user'] = state['user_menu']
|
||||||
output['seed'] = int(output['seed'])
|
output['seed'] = int(output['seed'])
|
||||||
|
output['custom_stopping_strings'] = output.get('custom_stopping_strings') or ''
|
||||||
|
output['custom_token_bans'] = output.get('custom_token_bans') or ''
|
||||||
output['show_controls'] = show_controls
|
output['show_controls'] = show_controls
|
||||||
output['dark_theme'] = True if theme_state == 'dark' else False
|
output['dark_theme'] = True if theme_state == 'dark' else False
|
||||||
output.pop('instruction_template_str')
|
output.pop('instruction_template_str')
|
||||||
|
|
@ -377,7 +337,7 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma
|
||||||
output[_id] = params[param]
|
output[_id] = params[param]
|
||||||
else:
|
else:
|
||||||
# Preserve existing extensions and extension parameters during autosave
|
# Preserve existing extensions and extension parameters during autosave
|
||||||
settings_path = Path('user_data') / 'settings.yaml'
|
settings_path = shared.user_data_dir / 'settings.yaml'
|
||||||
if settings_path.exists():
|
if settings_path.exists():
|
||||||
try:
|
try:
|
||||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||||
|
|
@ -432,7 +392,7 @@ def _perform_debounced_save():
|
||||||
try:
|
try:
|
||||||
if _last_interface_state is not None:
|
if _last_interface_state is not None:
|
||||||
contents = save_settings(_last_interface_state, _last_preset, _last_extensions, _last_show_controls, _last_theme_state, manual_save=False)
|
contents = save_settings(_last_interface_state, _last_preset, _last_extensions, _last_show_controls, _last_theme_state, manual_save=False)
|
||||||
settings_path = Path('user_data') / 'settings.yaml'
|
settings_path = shared.user_data_dir / 'settings.yaml'
|
||||||
settings_path.parent.mkdir(exist_ok=True)
|
settings_path.parent.mkdir(exist_ok=True)
|
||||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||||
f.write(contents)
|
f.write(contents)
|
||||||
|
|
@ -457,6 +417,7 @@ def setup_auto_save():
|
||||||
'chat_style',
|
'chat_style',
|
||||||
'chat-instruct_command',
|
'chat-instruct_command',
|
||||||
'character_menu',
|
'character_menu',
|
||||||
|
'user_menu',
|
||||||
'name1',
|
'name1',
|
||||||
'name2',
|
'name2',
|
||||||
'context',
|
'context',
|
||||||
|
|
@ -464,6 +425,7 @@ def setup_auto_save():
|
||||||
'user_bio',
|
'user_bio',
|
||||||
'custom_system_message',
|
'custom_system_message',
|
||||||
'chat_template_str',
|
'chat_template_str',
|
||||||
|
'selected_tools',
|
||||||
|
|
||||||
# Parameters tab (ui_parameters.py) - Generation parameters
|
# Parameters tab (ui_parameters.py) - Generation parameters
|
||||||
'preset_menu',
|
'preset_menu',
|
||||||
|
|
@ -484,6 +446,8 @@ def setup_auto_save():
|
||||||
'tfs',
|
'tfs',
|
||||||
'top_a',
|
'top_a',
|
||||||
'top_n_sigma',
|
'top_n_sigma',
|
||||||
|
'adaptive_target',
|
||||||
|
'adaptive_decay',
|
||||||
'dry_multiplier',
|
'dry_multiplier',
|
||||||
'dry_allowed_length',
|
'dry_allowed_length',
|
||||||
'dry_base',
|
'dry_base',
|
||||||
|
|
@ -512,7 +476,6 @@ def setup_auto_save():
|
||||||
'skip_special_tokens',
|
'skip_special_tokens',
|
||||||
'stream',
|
'stream',
|
||||||
'static_cache',
|
'static_cache',
|
||||||
'truncation_length',
|
|
||||||
'seed',
|
'seed',
|
||||||
'sampler_priority',
|
'sampler_priority',
|
||||||
'custom_stopping_strings',
|
'custom_stopping_strings',
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ def create_ui():
|
||||||
|
|
||||||
shared.gradio['Chat input'] = gr.State()
|
shared.gradio['Chat input'] = gr.State()
|
||||||
shared.gradio['history'] = gr.State({'internal': [], 'visible': [], 'metadata': {}})
|
shared.gradio['history'] = gr.State({'internal': [], 'visible': [], 'metadata': {}})
|
||||||
shared.gradio['display'] = gr.JSON(value={}, visible=False) # Hidden buffer
|
shared.gradio['display'] = gr.Headless(value={})
|
||||||
|
|
||||||
with gr.Tab('Chat', elem_id='chat-tab'):
|
with gr.Tab('Chat', elem_id='chat-tab'):
|
||||||
with gr.Row(elem_id='past-chats-row', elem_classes=['pretty_scrollbar']):
|
with gr.Row(elem_id='past-chats-row', elem_classes=['pretty_scrollbar']):
|
||||||
|
|
@ -28,7 +28,8 @@ def create_ui():
|
||||||
shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu)
|
shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu)
|
||||||
shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu)
|
shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu)
|
||||||
shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat')
|
shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat')
|
||||||
shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'])
|
shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'], elem_id='new-chat-btn')
|
||||||
|
shared.gradio['Start incognito chat'] = gr.Button('Incognito chat', visible=False, elem_id='incognito-chat-btn')
|
||||||
shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True)
|
shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True)
|
||||||
|
|
||||||
shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat')
|
shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat')
|
||||||
|
|
@ -91,6 +92,21 @@ def create_ui():
|
||||||
|
|
||||||
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
||||||
|
|
||||||
|
from modules.tool_use import get_available_tools
|
||||||
|
shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=shared.settings.get('selected_tools', []), label='Tools', info='Functions the model can call during generation.', elem_id='tools-group')
|
||||||
|
shared.gradio['tools_refresh'] = gr.Button('Refresh list', elem_id='tools-refresh-btn', visible=False)
|
||||||
|
shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']])
|
||||||
|
|
||||||
|
def sync_web_tools(selected):
|
||||||
|
if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools():
|
||||||
|
selected.append('fetch_webpage')
|
||||||
|
|
||||||
|
return gr.update(value=selected)
|
||||||
|
|
||||||
|
shared.gradio['selected_tools'].change(fn=sync_web_tools, inputs=[shared.gradio['selected_tools']], outputs=[shared.gradio['selected_tools']], show_progress=False)
|
||||||
|
|
||||||
|
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')
|
shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')
|
||||||
|
|
||||||
|
|
@ -137,6 +153,12 @@ def create_character_settings_ui():
|
||||||
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=5, label='Greeting', elem_classes=['add_scrollbar'], elem_id="character-greeting")
|
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=5, label='Greeting', elem_classes=['add_scrollbar'], elem_id="character-greeting")
|
||||||
|
|
||||||
with gr.Tab("User"):
|
with gr.Tab("User"):
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['user_menu'] = gr.Dropdown(value=shared.settings['user'], choices=utils.get_available_users(), label='User', elem_id='user-menu', info='Select a user profile.', elem_classes='slim-dropdown')
|
||||||
|
ui.create_refresh_button(shared.gradio['user_menu'], lambda: None, lambda: {'choices': utils.get_available_users()}, 'refresh-button', interactive=not mu)
|
||||||
|
shared.gradio['save_user'] = gr.Button('💾', elem_classes='refresh-button', elem_id="save-user", interactive=not mu)
|
||||||
|
shared.gradio['delete_user'] = gr.Button('🗑️', elem_classes='refresh-button', interactive=not mu)
|
||||||
|
|
||||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Name')
|
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Name')
|
||||||
shared.gradio['user_bio'] = gr.Textbox(value=shared.settings['user_bio'], lines=10, label='Description', info='Here you can optionally write a description of yourself.', placeholder='{{user}}\'s personality: ...', elem_classes=['add_scrollbar'], elem_id="user-description")
|
shared.gradio['user_bio'] = gr.Textbox(value=shared.settings['user_bio'], lines=10, label='Description', info='Here you can optionally write a description of yourself.', placeholder='{{user}}\'s personality: ...', elem_classes=['add_scrollbar'], elem_id="user-description")
|
||||||
|
|
||||||
|
|
@ -169,7 +191,7 @@ def create_character_settings_ui():
|
||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
shared.gradio['character_picture'] = gr.Image(label='Character picture', type='filepath', interactive=not mu)
|
shared.gradio['character_picture'] = gr.Image(label='Character picture', type='filepath', interactive=not mu)
|
||||||
shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(Path('user_data/cache/pfp_me.png')) if Path('user_data/cache/pfp_me.png').exists() else None, interactive=not mu)
|
shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(shared.user_data_dir / 'cache' / 'pfp_me.png') if (shared.user_data_dir / 'cache' / 'pfp_me.png').exists() else None, interactive=not mu)
|
||||||
|
|
||||||
|
|
||||||
def create_chat_settings_ui():
|
def create_chat_settings_ui():
|
||||||
|
|
@ -269,6 +291,10 @@ def create_event_handlers():
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['Start incognito chat'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.handle_start_incognito_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['delete_chat-confirm'].click(
|
shared.gradio['delete_chat-confirm'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False)
|
||||||
|
|
@ -324,13 +350,13 @@ def create_event_handlers():
|
||||||
shared.gradio['load_template'].click(chat.handle_load_template_click, gradio('instruction_template'), gradio('instruction_template_str', 'instruction_template'), show_progress=False)
|
shared.gradio['load_template'].click(chat.handle_load_template_click, gradio('instruction_template'), gradio('instruction_template_str', 'instruction_template'), show_progress=False)
|
||||||
shared.gradio['save_template'].click(
|
shared.gradio['save_template'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
chat.handle_save_template_click, gradio('instruction_template_str'), gradio('save_filename', 'save_root', 'save_contents', 'file_saver'), show_progress=False)
|
chat.handle_save_template_click, gradio('instruction_template_str'), gradio('save_filename', 'save_root', 'save_contents', 'save_root_state', 'file_saver'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['restore_character'].click(
|
shared.gradio['restore_character'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
chat.restore_character_for_ui, gradio('interface_state'), gradio('interface_state', 'name2', 'context', 'greeting', 'character_picture'), show_progress=False)
|
chat.restore_character_for_ui, gradio('interface_state'), gradio('interface_state', 'name2', 'context', 'greeting', 'character_picture'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['delete_template'].click(chat.handle_delete_template_click, gradio('instruction_template'), gradio('delete_filename', 'delete_root', 'file_deleter'), show_progress=False)
|
shared.gradio['delete_template'].click(chat.handle_delete_template_click, gradio('instruction_template'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False)
|
||||||
shared.gradio['save_chat_history'].click(
|
shared.gradio['save_chat_history'].click(
|
||||||
lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then(
|
lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then(
|
||||||
None, gradio('temporary_text', 'character_menu', 'mode'), None, js=f'(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}')
|
None, gradio('temporary_text', 'character_menu', 'mode'), None, js=f'(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}')
|
||||||
|
|
@ -372,3 +398,11 @@ def create_event_handlers():
|
||||||
gradio('enable_web_search'),
|
gradio('enable_web_search'),
|
||||||
gradio('web_search_row')
|
gradio('web_search_row')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# User menu event handlers
|
||||||
|
shared.gradio['user_menu'].change(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
chat.handle_user_menu_change, gradio('interface_state'), gradio('name1', 'user_bio', 'your_picture'), show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['save_user'].click(chat.handle_save_user_click, gradio('name1'), gradio('save_user_filename', 'user_saver'), show_progress=False)
|
||||||
|
shared.gradio['delete_user'].click(lambda: gr.update(visible=True), None, gradio('user_deleter'), show_progress=False)
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ def handle_new_prompt():
|
||||||
new_name = utils.current_time()
|
new_name = utils.current_time()
|
||||||
|
|
||||||
# Create the new prompt file
|
# Create the new prompt file
|
||||||
prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
prompt_path.write_text("In this story,", encoding='utf-8')
|
prompt_path.write_text("In this story,", encoding='utf-8')
|
||||||
|
|
||||||
|
|
@ -170,15 +170,15 @@ def handle_delete_prompt_confirm_default(prompt_name):
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
||||||
|
|
||||||
(Path("user_data/logs/notebook") / f"{prompt_name}.txt").unlink(missing_ok=True)
|
(shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True)
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
|
|
||||||
if available_prompts:
|
if available_prompts:
|
||||||
new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
|
new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
|
||||||
else:
|
else:
|
||||||
new_value = utils.current_time()
|
new_value = utils.current_time()
|
||||||
Path("user_data/logs/notebook").mkdir(parents=True, exist_ok=True)
|
(shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True)
|
||||||
(Path("user_data/logs/notebook") / f"{new_value}.txt").write_text("In this story,")
|
(shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,")
|
||||||
available_prompts = [new_value]
|
available_prompts = [new_value]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|
@ -199,8 +199,8 @@ def handle_rename_prompt_click_default(current_name):
|
||||||
|
|
||||||
|
|
||||||
def handle_rename_prompt_confirm_default(new_name, current_name):
|
def handle_rename_prompt_confirm_default(new_name, current_name):
|
||||||
old_path = Path("user_data/logs/notebook") / f"{current_name}.txt"
|
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt"
|
||||||
new_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
|
|
||||||
if old_path.exists() and not new_path.exists():
|
if old_path.exists() and not new_path.exists():
|
||||||
old_path.rename(new_path)
|
old_path.rename(new_path)
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,18 @@ import traceback
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import chat, presets, shared, ui, utils
|
from modules import chat, presets, shared, ui, utils
|
||||||
from modules.utils import gradio
|
from modules.utils import gradio, sanitize_filename
|
||||||
|
|
||||||
|
|
||||||
def create_ui():
|
def create_ui():
|
||||||
mu = shared.args.multi_user
|
mu = shared.args.multi_user
|
||||||
|
|
||||||
|
# Server-side per-session root paths for the generic file saver/deleter.
|
||||||
|
# Set by the handler that opens the dialog, read by the confirm handler.
|
||||||
|
# Using gr.State so they are session-scoped and safe for multi-user.
|
||||||
|
shared.gradio['save_root_state'] = gr.State(None)
|
||||||
|
shared.gradio['delete_root_state'] = gr.State(None)
|
||||||
|
|
||||||
# Text file saver
|
# Text file saver
|
||||||
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']:
|
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']:
|
||||||
shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name')
|
shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name')
|
||||||
|
|
@ -28,7 +34,7 @@ def create_ui():
|
||||||
|
|
||||||
# Character saver/deleter
|
# Character saver/deleter
|
||||||
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']:
|
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']:
|
||||||
shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info='The character will be saved to your user_data/characters folder with this base filename.')
|
shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info=f'The character will be saved to your {shared.user_data_dir}/characters folder with this base filename.')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)
|
shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)
|
||||||
|
|
@ -39,9 +45,22 @@ def create_ui():
|
||||||
shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
shared.gradio['delete_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu)
|
shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu)
|
||||||
|
|
||||||
|
# User saver/deleter
|
||||||
|
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_saver']:
|
||||||
|
shared.gradio['save_user_filename'] = gr.Textbox(lines=1, label='File name', info=f'The user profile will be saved to your {shared.user_data_dir}/users folder with this base filename.')
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['save_user_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
|
shared.gradio['save_user_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)
|
||||||
|
|
||||||
|
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_deleter']:
|
||||||
|
gr.Markdown('Confirm the user deletion?')
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['delete_user_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
|
shared.gradio['delete_user_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu)
|
||||||
|
|
||||||
# Preset saver
|
# Preset saver
|
||||||
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['preset_saver']:
|
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['preset_saver']:
|
||||||
shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info='The preset will be saved to your user_data/presets folder with this base filename.')
|
shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info=f'The preset will be saved to your {shared.user_data_dir}/presets folder with this base filename.')
|
||||||
shared.gradio['save_preset_contents'] = gr.Textbox(lines=10, label='File contents')
|
shared.gradio['save_preset_contents'] = gr.Textbox(lines=10, label='File contents')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['save_preset_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
shared.gradio['save_preset_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
|
|
@ -53,13 +72,13 @@ def create_event_handlers():
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
handle_save_preset_click, gradio('interface_state'), gradio('save_preset_contents', 'save_preset_filename', 'preset_saver'), show_progress=False)
|
handle_save_preset_click, gradio('interface_state'), gradio('save_preset_contents', 'save_preset_filename', 'preset_saver'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['delete_preset'].click(handle_delete_preset_click, gradio('preset_menu'), gradio('delete_filename', 'delete_root', 'file_deleter'), show_progress=False)
|
shared.gradio['delete_preset'].click(handle_delete_preset_click, gradio('preset_menu'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False)
|
||||||
shared.gradio['save_grammar'].click(handle_save_grammar_click, gradio('grammar_string'), gradio('save_contents', 'save_filename', 'save_root', 'file_saver'), show_progress=False)
|
shared.gradio['save_grammar'].click(handle_save_grammar_click, gradio('grammar_string'), gradio('save_contents', 'save_filename', 'save_root', 'save_root_state', 'file_saver'), show_progress=False)
|
||||||
shared.gradio['delete_grammar'].click(handle_delete_grammar_click, gradio('grammar_file'), gradio('delete_filename', 'delete_root', 'file_deleter'), show_progress=False)
|
shared.gradio['delete_grammar'].click(handle_delete_grammar_click, gradio('grammar_file'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['save_preset_confirm'].click(handle_save_preset_confirm_click, gradio('save_preset_filename', 'save_preset_contents'), gradio('preset_menu', 'preset_saver'), show_progress=False)
|
shared.gradio['save_preset_confirm'].click(handle_save_preset_confirm_click, gradio('save_preset_filename', 'save_preset_contents'), gradio('preset_menu', 'preset_saver'), show_progress=False)
|
||||||
shared.gradio['save_confirm'].click(handle_save_confirm_click, gradio('save_root', 'save_filename', 'save_contents'), gradio('file_saver'), show_progress=False)
|
shared.gradio['save_confirm'].click(handle_save_confirm_click, gradio('save_root_state', 'save_filename', 'save_contents'), gradio('save_root_state', 'file_saver'), show_progress=False)
|
||||||
shared.gradio['delete_confirm'].click(handle_delete_confirm_click, gradio('delete_root', 'delete_filename'), gradio('file_deleter'), show_progress=False)
|
shared.gradio['delete_confirm'].click(handle_delete_confirm_click, gradio('delete_root_state', 'delete_filename'), gradio('delete_root_state', 'file_deleter'), show_progress=False)
|
||||||
shared.gradio['save_character_confirm'].click(handle_save_character_confirm_click, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), gradio('character_menu', 'character_saver'), show_progress=False)
|
shared.gradio['save_character_confirm'].click(handle_save_character_confirm_click, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), gradio('character_menu', 'character_saver'), show_progress=False)
|
||||||
shared.gradio['delete_character_confirm'].click(handle_delete_character_confirm_click, gradio('character_menu'), gradio('character_menu', 'character_deleter'), show_progress=False)
|
shared.gradio['delete_character_confirm'].click(handle_delete_character_confirm_click, gradio('character_menu'), gradio('character_menu', 'character_deleter'), show_progress=False)
|
||||||
|
|
||||||
|
|
@ -69,10 +88,17 @@ def create_event_handlers():
|
||||||
shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'), show_progress=False)
|
shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'), show_progress=False)
|
||||||
shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'), show_progress=False)
|
shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'), show_progress=False)
|
||||||
|
|
||||||
|
# User save/delete event handlers
|
||||||
|
shared.gradio['save_user_confirm'].click(handle_save_user_confirm_click, gradio('name1', 'user_bio', 'your_picture', 'save_user_filename'), gradio('user_menu', 'user_saver'), show_progress=False)
|
||||||
|
shared.gradio['delete_user_confirm'].click(handle_delete_user_confirm_click, gradio('user_menu'), gradio('user_menu', 'user_deleter'), show_progress=False)
|
||||||
|
shared.gradio['save_user_cancel'].click(lambda: gr.update(visible=False), None, gradio('user_saver'), show_progress=False)
|
||||||
|
shared.gradio['delete_user_cancel'].click(lambda: gr.update(visible=False), None, gradio('user_deleter'), show_progress=False)
|
||||||
|
|
||||||
|
|
||||||
def handle_save_preset_confirm_click(filename, contents):
|
def handle_save_preset_confirm_click(filename, contents):
|
||||||
try:
|
try:
|
||||||
utils.save_file(f"user_data/presets/{filename}.yaml", contents)
|
filename = sanitize_filename(filename)
|
||||||
|
utils.save_file(str(shared.user_data_dir / "presets" / f"{filename}.yaml"), contents)
|
||||||
available_presets = utils.get_available_presets()
|
available_presets = utils.get_available_presets()
|
||||||
output = gr.update(choices=available_presets, value=filename)
|
output = gr.update(choices=available_presets, value=filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -85,22 +111,30 @@ def handle_save_preset_confirm_click(filename, contents):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def handle_save_confirm_click(root, filename, contents):
|
def handle_save_confirm_click(root_state, filename, contents):
|
||||||
try:
|
try:
|
||||||
utils.save_file(root + filename, contents)
|
if root_state is None:
|
||||||
|
return None, gr.update(visible=False)
|
||||||
|
|
||||||
|
filename = sanitize_filename(filename)
|
||||||
|
utils.save_file(root_state + filename, contents)
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
return gr.update(visible=False)
|
return None, gr.update(visible=False)
|
||||||
|
|
||||||
|
|
||||||
def handle_delete_confirm_click(root, filename):
|
def handle_delete_confirm_click(root_state, filename):
|
||||||
try:
|
try:
|
||||||
utils.delete_file(root + filename)
|
if root_state is None:
|
||||||
|
return None, gr.update(visible=False)
|
||||||
|
|
||||||
|
filename = sanitize_filename(filename)
|
||||||
|
utils.delete_file(root_state + filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
return gr.update(visible=False)
|
return None, gr.update(visible=False)
|
||||||
|
|
||||||
|
|
||||||
def handle_save_character_confirm_click(name2, greeting, context, character_picture, filename):
|
def handle_save_character_confirm_click(name2, greeting, context, character_picture, filename):
|
||||||
|
|
@ -143,25 +177,61 @@ def handle_save_preset_click(state):
|
||||||
|
|
||||||
|
|
||||||
def handle_delete_preset_click(preset):
|
def handle_delete_preset_click(preset):
|
||||||
|
root = str(shared.user_data_dir / "presets") + "/"
|
||||||
return [
|
return [
|
||||||
f"{preset}.yaml",
|
f"{preset}.yaml",
|
||||||
"user_data/presets/",
|
root,
|
||||||
|
root,
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def handle_save_grammar_click(grammar_string):
|
def handle_save_grammar_click(grammar_string):
|
||||||
|
root = str(shared.user_data_dir / "grammars") + "/"
|
||||||
return [
|
return [
|
||||||
grammar_string,
|
grammar_string,
|
||||||
"My Fancy Grammar.gbnf",
|
"My Fancy Grammar.gbnf",
|
||||||
"user_data/grammars/",
|
root,
|
||||||
|
root,
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def handle_delete_grammar_click(grammar_file):
|
def handle_delete_grammar_click(grammar_file):
|
||||||
|
root = str(shared.user_data_dir / "grammars") + "/"
|
||||||
return [
|
return [
|
||||||
grammar_file,
|
grammar_file,
|
||||||
"user_data/grammars/",
|
root,
|
||||||
|
root,
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def handle_save_user_confirm_click(name1, user_bio, your_picture, filename):
|
||||||
|
try:
|
||||||
|
chat.save_user(name1, user_bio, your_picture, filename)
|
||||||
|
available_users = utils.get_available_users()
|
||||||
|
output = gr.update(choices=available_users, value=filename)
|
||||||
|
except Exception:
|
||||||
|
output = gr.update()
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
return [
|
||||||
|
output,
|
||||||
|
gr.update(visible=False)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def handle_delete_user_confirm_click(user):
|
||||||
|
try:
|
||||||
|
index = str(utils.get_available_users().index(user))
|
||||||
|
chat.delete_user(user)
|
||||||
|
output = chat.update_user_menu_after_deletion(index)
|
||||||
|
except Exception:
|
||||||
|
output = gr.update()
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
return [
|
||||||
|
output,
|
||||||
|
gr.update(visible=False)
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,7 @@ def save_generated_images(images, state, actual_seed):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||||
folder_path = os.path.join("user_data", "image_outputs", date_str)
|
folder_path = str(shared.user_data_dir / "image_outputs" / date_str)
|
||||||
os.makedirs(folder_path, exist_ok=True)
|
os.makedirs(folder_path, exist_ok=True)
|
||||||
|
|
||||||
metadata = build_generation_metadata(state, actual_seed)
|
metadata = build_generation_metadata(state, actual_seed)
|
||||||
|
|
@ -214,7 +214,7 @@ def get_all_history_images(force_refresh=False):
|
||||||
"""Get all history images sorted by modification time (newest first). Uses caching."""
|
"""Get all history images sorted by modification time (newest first). Uses caching."""
|
||||||
global _image_cache, _cache_timestamp
|
global _image_cache, _cache_timestamp
|
||||||
|
|
||||||
output_dir = os.path.join("user_data", "image_outputs")
|
output_dir = str(shared.user_data_dir / "image_outputs")
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
@ -728,6 +728,8 @@ def generate_prompt_variation(state):
|
||||||
variation = variation.rsplit("</think>", 1)[1]
|
variation = variation.rsplit("</think>", 1)[1]
|
||||||
elif "<|start|>assistant<|channel|>final<|message|>" in variation:
|
elif "<|start|>assistant<|channel|>final<|message|>" in variation:
|
||||||
variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
|
variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
|
||||||
|
elif "<|channel|>final<|message|>" in variation:
|
||||||
|
variation = variation.rsplit("<|channel|>final<|message|>", 1)[1]
|
||||||
elif "</seed:think>" in variation:
|
elif "</seed:think>" in variation:
|
||||||
variation = variation.rsplit("</seed:think>", 1)[1]
|
variation = variation.rsplit("</seed:think>", 1)[1]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,11 +41,12 @@ def create_ui():
|
||||||
gr.Markdown("## Main options")
|
gr.Markdown("## Main options")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=0, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
|
shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=-1, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Number of layers to offload to the GPU. -1 = auto.')
|
||||||
shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=256, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
|
shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=1048576, step=1024, value=shared.args.ctx_size, info='Context length. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
|
||||||
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
|
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
|
||||||
shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
|
shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
|
||||||
shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
|
shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
|
||||||
|
shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices.')
|
||||||
shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
|
shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
|
@ -55,32 +56,43 @@ def create_ui():
|
||||||
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
|
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
|
||||||
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
|
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
|
||||||
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.')
|
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.')
|
||||||
shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
|
|
||||||
shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable tensor parallelism (TP).')
|
shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable tensor parallelism (TP).')
|
||||||
shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
|
shared.gradio['tensorrt_llm_info'] = gr.Markdown(
|
||||||
shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
|
'* TensorRT-LLM has to be installed manually: `pip install tensorrt_llm==1.1.0 --extra-index-url https://pypi.nvidia.com`.\n\n'
|
||||||
|
'* You can load either a pre-built TensorRT engine or a regular HF model. '
|
||||||
|
'HF models will be compiled to a TensorRT engine automatically on each load (this can take a while).'
|
||||||
|
)
|
||||||
|
|
||||||
# Multimodal
|
# Multimodal
|
||||||
with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
|
with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
|
shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info=f'Select a file that matches your model. Must be placed in {shared.user_data_dir}/mmproj/', interactive=not mu)
|
||||||
ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
|
with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
|
||||||
|
shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Maximum number of tokens to draft for speculative decoding. Recommended: 4 for draft model, 64 for n-gram.')
|
||||||
|
|
||||||
|
gr.Markdown('#### Draft model')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=['None'] + utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', info='Draft model. Speculative decoding only works with models sharing the same vocabulary (e.g., same model family).', interactive=not mu)
|
shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=['None'] + utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', info='Draft model. Must share the same vocabulary as the main model.', interactive=not mu)
|
||||||
ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': ['None'] + utils.get_available_models()}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': ['None'] + utils.get_available_models()}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.')
|
shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.')
|
||||||
shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding. Recommended value: 4.')
|
|
||||||
shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
|
shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
|
||||||
shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
|
shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
|
||||||
|
|
||||||
|
shared.gradio['ngram_header'] = gr.Markdown('#### N-gram (draftless)')
|
||||||
|
shared.gradio['spec_type'] = gr.Dropdown(label="spec-type", choices=['none', 'ngram-mod', 'ngram-simple', 'ngram-map-k', 'ngram-map-k4v', 'ngram-cache'], value=shared.args.spec_type, info='Draftless speculative decoding type. Recommended: ngram-mod.')
|
||||||
|
shared.gradio['spec_ngram_size_n'] = gr.Number(label="spec-ngram-size-n", precision=0, step=1, value=shared.args.spec_ngram_size_n, info='N-gram lookup size for speculative decoding.', visible=shared.args.spec_type != 'none')
|
||||||
|
shared.gradio['spec_ngram_size_m'] = gr.Number(label="spec-ngram-size-m", precision=0, step=1, value=shared.args.spec_ngram_size_m, info='Draft n-gram size for speculative decoding.', visible=shared.args.spec_type != 'none')
|
||||||
|
shared.gradio['spec_ngram_min_hits'] = gr.Number(label="spec-ngram-min-hits", precision=0, step=1, value=shared.args.spec_ngram_min_hits, info='Minimum n-gram hits for ngram-map speculative decoding.', visible=shared.args.spec_type != 'none')
|
||||||
|
|
||||||
gr.Markdown("## Other options")
|
gr.Markdown("## Other options")
|
||||||
with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
|
with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
shared.gradio['parallel'] = gr.Slider(label="parallel", minimum=1, step=1, maximum=64, value=shared.args.parallel, info='Number of parallel request slots for the API. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
|
||||||
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
|
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
|
||||||
shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
|
shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
|
||||||
shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
|
shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
|
||||||
|
|
@ -88,12 +100,8 @@ def create_ui():
|
||||||
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
|
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
|
||||||
shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
|
shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
|
||||||
shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
|
shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
|
||||||
shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
|
|
||||||
shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
|
|
||||||
shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
|
|
||||||
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
|
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
|
||||||
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
|
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
|
||||||
shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
|
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.')
|
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.')
|
||||||
|
|
@ -104,9 +112,6 @@ def create_ui():
|
||||||
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
|
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
|
||||||
shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
|
shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
|
||||||
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
|
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
|
||||||
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn)
|
|
||||||
shared.gradio['no_xformers'] = gr.Checkbox(label="no_xformers", value=shared.args.no_xformers)
|
|
||||||
shared.gradio['no_sdpa'] = gr.Checkbox(label="no_sdpa", value=shared.args.no_sdpa)
|
|
||||||
shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
|
shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
|
||||||
shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
|
shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
|
||||||
if not shared.args.portable:
|
if not shared.args.portable:
|
||||||
|
|
@ -157,28 +162,35 @@ def create_event_handlers():
|
||||||
handle_load_model_event_final, gradio('truncation_length', 'loader', 'interface_state'), gradio('truncation_length', 'filter_by_loader'), show_progress=False)
|
handle_load_model_event_final, gradio('truncation_length', 'loader', 'interface_state'), gradio('truncation_length', 'filter_by_loader'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['unload_model'].click(handle_unload_model_click, None, gradio('model_status'), show_progress=False).then(
|
shared.gradio['unload_model'].click(handle_unload_model_click, None, gradio('model_status'), show_progress=False).then(
|
||||||
partial(update_gpu_layers_and_vram, auto_adjust=True), gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info', 'gpu_layers'), show_progress=False)
|
update_gpu_layers_and_vram, gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False)
|
||||||
|
|
||||||
shared.gradio['save_model_settings'].click(
|
shared.gradio['save_model_settings'].click(
|
||||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False)
|
save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False)
|
||||||
|
|
||||||
# For ctx_size and cache_type - auto-adjust GPU layers
|
# For ctx_size and cache_type - update VRAM display
|
||||||
for param in ['ctx_size', 'cache_type']:
|
for param in ['ctx_size', 'cache_type']:
|
||||||
shared.gradio[param].change(
|
shared.gradio[param].change(
|
||||||
partial(update_gpu_layers_and_vram, auto_adjust=True),
|
update_gpu_layers_and_vram,
|
||||||
gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'),
|
gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'),
|
||||||
gradio('vram_info', 'gpu_layers'), show_progress=False)
|
gradio('vram_info'), show_progress=False)
|
||||||
|
|
||||||
# For manual gpu_layers changes - only update VRAM
|
# For manual gpu_layers changes - only update VRAM
|
||||||
shared.gradio['gpu_layers'].change(
|
shared.gradio['gpu_layers'].change(
|
||||||
partial(update_gpu_layers_and_vram, auto_adjust=False),
|
update_gpu_layers_and_vram,
|
||||||
gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'),
|
gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'),
|
||||||
gradio('vram_info'), show_progress=False)
|
gradio('vram_info'), show_progress=False)
|
||||||
|
|
||||||
if not shared.args.portable:
|
if not shared.args.portable:
|
||||||
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
|
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['spec_type'].change(
|
||||||
|
lambda x: [gr.update(visible=x != 'none')] * 3,
|
||||||
|
gradio('spec_type'),
|
||||||
|
gradio('spec_ngram_size_n', 'spec_ngram_size_m', 'spec_ngram_min_hits'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
|
shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
|
||||||
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
|
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
|
||||||
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
|
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
|
||||||
|
|
@ -209,7 +221,7 @@ def load_model_wrapper(selected_model, loader, autoload=False):
|
||||||
yield f"Successfully loaded `{selected_model}`."
|
yield f"Successfully loaded `{selected_model}`."
|
||||||
else:
|
else:
|
||||||
yield f"Failed to load `{selected_model}`."
|
yield f"Failed to load `{selected_model}`."
|
||||||
except:
|
except Exception:
|
||||||
exc = traceback.format_exc()
|
exc = traceback.format_exc()
|
||||||
logger.error('Failed to load the model.')
|
logger.error('Failed to load the model.')
|
||||||
print(exc)
|
print(exc)
|
||||||
|
|
@ -303,9 +315,9 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
|
||||||
model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None
|
model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if output_folder == Path("user_data/models"):
|
if output_folder == shared.user_data_dir / "models":
|
||||||
output_folder = Path(shared.args.model_dir)
|
output_folder = Path(shared.args.model_dir)
|
||||||
elif output_folder == Path("user_data/loras"):
|
elif output_folder == shared.user_data_dir / "loras":
|
||||||
output_folder = Path(shared.args.lora_dir)
|
output_folder = Path(shared.args.lora_dir)
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
|
|
@ -373,8 +385,12 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
|
||||||
def update_truncation_length(current_length, state):
|
def update_truncation_length(current_length, state):
|
||||||
if 'loader' in state:
|
if 'loader' in state:
|
||||||
if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
|
if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
|
||||||
|
if state['ctx_size'] > 0:
|
||||||
return state['ctx_size']
|
return state['ctx_size']
|
||||||
|
|
||||||
|
# ctx_size == 0 means auto: use the actual value from the server
|
||||||
|
return shared.settings['truncation_length']
|
||||||
|
|
||||||
return current_length
|
return current_length
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -386,8 +402,6 @@ def get_initial_vram_info():
|
||||||
shared.args.gpu_layers,
|
shared.args.gpu_layers,
|
||||||
shared.args.ctx_size,
|
shared.args.ctx_size,
|
||||||
shared.args.cache_type,
|
shared.args.cache_type,
|
||||||
auto_adjust=False,
|
|
||||||
for_ui=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return "<div id=\"vram-info\"'>Estimated VRAM to load the model:</div>"
|
return "<div id=\"vram-info\"'>Estimated VRAM to load the model:</div>"
|
||||||
|
|
@ -396,7 +410,7 @@ def get_initial_vram_info():
|
||||||
def get_initial_gpu_layers_max():
|
def get_initial_gpu_layers_max():
|
||||||
if shared.model_name != 'None' and shared.args.loader == 'llama.cpp':
|
if shared.model_name != 'None' and shared.args.loader == 'llama.cpp':
|
||||||
model_settings = get_model_metadata(shared.model_name)
|
model_settings = get_model_metadata(shared.model_name)
|
||||||
return model_settings.get('max_gpu_layers', model_settings.get('gpu_layers', 256))
|
return model_settings.get('max_gpu_layers', 256)
|
||||||
|
|
||||||
return 256
|
return 256
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -194,7 +194,7 @@ def handle_new_prompt():
|
||||||
new_name = utils.current_time()
|
new_name = utils.current_time()
|
||||||
|
|
||||||
# Create the new prompt file
|
# Create the new prompt file
|
||||||
prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
prompt_path.write_text("In this story,", encoding='utf-8')
|
prompt_path.write_text("In this story,", encoding='utf-8')
|
||||||
|
|
||||||
|
|
@ -205,15 +205,15 @@ def handle_delete_prompt_confirm_notebook(prompt_name):
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
||||||
|
|
||||||
(Path("user_data/logs/notebook") / f"{prompt_name}.txt").unlink(missing_ok=True)
|
(shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True)
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
|
|
||||||
if available_prompts:
|
if available_prompts:
|
||||||
new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
|
new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
|
||||||
else:
|
else:
|
||||||
new_value = utils.current_time()
|
new_value = utils.current_time()
|
||||||
Path("user_data/logs/notebook").mkdir(parents=True, exist_ok=True)
|
(shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True)
|
||||||
(Path("user_data/logs/notebook") / f"{new_value}.txt").write_text("In this story,")
|
(shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,")
|
||||||
available_prompts = [new_value]
|
available_prompts = [new_value]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|
@ -233,8 +233,8 @@ def handle_rename_prompt_click_notebook(current_name):
|
||||||
|
|
||||||
|
|
||||||
def handle_rename_prompt_confirm_notebook(new_name, current_name):
|
def handle_rename_prompt_confirm_notebook(new_name, current_name):
|
||||||
old_path = Path("user_data/logs/notebook") / f"{current_name}.txt"
|
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt"
|
||||||
new_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
|
|
||||||
if old_path.exists() and not new_path.exists():
|
if old_path.exists() and not new_path.exists():
|
||||||
old_path.rename(new_path)
|
old_path.rename(new_path)
|
||||||
|
|
@ -250,7 +250,7 @@ def handle_rename_prompt_confirm_notebook(new_name, current_name):
|
||||||
def autosave_prompt(text, prompt_name):
|
def autosave_prompt(text, prompt_name):
|
||||||
"""Automatically save the text to the selected prompt file"""
|
"""Automatically save the text to the selected prompt file"""
|
||||||
if prompt_name and text.strip():
|
if prompt_name and text.strip():
|
||||||
prompt_path = Path("user_data/logs/notebook") / f"{prompt_name}.txt"
|
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt"
|
||||||
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
prompt_path.write_text(text, encoding='utf-8')
|
prompt_path.write_text(text, encoding='utf-8')
|
||||||
|
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue