diff --git a/modules/exllamav3.py b/modules/exllamav3.py index d884bbf7..ea1f3dc9 100644 --- a/modules/exllamav3.py +++ b/modules/exllamav3.py @@ -9,6 +9,7 @@ from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant from exllamav3.generator import Job from exllamav3.generator.sampler import ( CustomSampler, + SS_AdaptiveP, SS_Argmax, SS_MinP, SS_PresFreqP, @@ -158,7 +159,7 @@ class Exllamav3Model: tokenizer=tokenizer, draft_model=draft_model, draft_cache=draft_cache, - num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0, + num_draft_tokens=shared.args.draft_max if draft_model is not None else 0, ) result = cls() @@ -302,7 +303,11 @@ class Exllamav3Model: ordered_samplers = sorted(unordered_samplers, key=custom_sort_key) # 5. Add the final sampling stage and build the sampler - ordered_samplers.append(SS_Sample()) + if state.get('adaptive_target', 0) > 0: + ordered_samplers.append(SS_AdaptiveP(state['adaptive_target'], state['adaptive_decay'])) + else: + ordered_samplers.append(SS_Sample()) + sampler = CustomSampler(ordered_samplers) # Encode prompt with embeddings (ExLlamaV3-specific) @@ -329,12 +334,14 @@ class Exllamav3Model: if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: stop_conditions.append(self.tokenizer.eos_token_id) + seed = state.get('seed', -1) job = Job( input_ids=input_ids, max_new_tokens=max_new_tokens, decode_special_tokens=not state['skip_special_tokens'], embeddings=image_embeddings if image_embeddings else None, sampler=sampler, + seed=seed if seed >= 0 else None, stop_conditions=stop_conditions if stop_conditions else None, ) diff --git a/modules/loaders.py b/modules/loaders.py index 94c9c593..c8b499b9 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -276,6 +276,8 @@ loaders_samplers = { 'min_p', 'top_p', 'top_k', + 'adaptive_target', + 'adaptive_decay', 'repetition_penalty', 'frequency_penalty', 'presence_penalty',