llama.cpp: Add speculative decoding (#6891)

This commit is contained in:
oobabooga 2025-04-23 20:10:16 -03:00 committed by GitHub
parent 9424ba17c8
commit e99c20bcb0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 61 additions and 2 deletions

View file

@@ -6,6 +6,7 @@ import subprocess
import sys
import threading
import time
from pathlib import Path
import llama_cpp_binaries
import requests
@@ -281,6 +282,25 @@ class LlamaServer:
cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
if shared.args.rope_freq_base > 0:
cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
# Speculative decoding: when a draft model is configured, resolve it to a
# concrete file and append the draft-related flags to `cmd`.
if shared.args.model_draft not in [None, 'None']:
    # Try the value as given first; if nothing exists there, look for it
    # under the configured models directory.
    path = Path(shared.args.model_draft)
    if not path.exists():
        path = Path(f'{shared.args.model_dir}/{shared.args.model_draft}')

    if path.is_file():
        model_file = path
    else:
        # Not a file: treat it as a directory and pick the first .gguf in
        # sorted order. Glob on the resolved `path` so that a directory
        # given by its own path is searched where it actually lives.
        gguf_files = sorted(path.glob('*.gguf'))
        if not gguf_files:
            raise FileNotFoundError(
                f"No .gguf file found for draft model '{shared.args.model_draft}' (looked in '{path}')"
            )

        model_file = gguf_files[0]

    # Keep every element of `cmd` a plain string, like the other arguments.
    cmd += ["--model-draft", str(model_file)]

    # Each optional flag is only passed when it has a positive / non-empty value.
    if shared.args.draft_max > 0:
        cmd += ["--draft-max", str(shared.args.draft_max)]
    if shared.args.gpu_layers_draft > 0:
        cmd += ["--gpu-layers-draft", str(shared.args.gpu_layers_draft)]
    if shared.args.device_draft:
        cmd += ["--device-draft", shared.args.device_draft]
    if shared.args.ctx_size_draft > 0:
        cmd += ["--ctx-size-draft", str(shared.args.ctx_size_draft)]
env = os.environ.copy()
if os.name == 'posix':