diff --git a/modules/chat.py b/modules/chat.py
index 66ab8c74..fd949907 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -11,6 +11,7 @@ from pathlib import Path
 
 import gradio as gr
 import yaml
+from jinja2.ext import loopcontrols
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from PIL import Image
 
@@ -35,7 +36,11 @@ def strftime_now(format):
     return datetime.now().strftime(format)
 
 
-jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+jinja_env = ImmutableSandboxedEnvironment(
+    trim_blocks=True,
+    lstrip_blocks=True,
+    extensions=[loopcontrols]
+)
 jinja_env.globals["strftime_now"] = strftime_now
 
 
diff --git a/modules/llama_cpp_python_hijack.py b/modules/llama_cpp_python_hijack.py
index f3872a74..c03c28a7 100644
--- a/modules/llama_cpp_python_hijack.py
+++ b/modules/llama_cpp_python_hijack.py
@@ -121,5 +121,55 @@ def monkey_patch_llama_cpp_python(lib):
     lib.Llama.original_generate = lib.Llama.generate
     lib.Llama.generate = my_generate
 
+    # Also patch Jinja2ChatFormatter so chat templates may use the
+    # {% break %} / {% continue %} tags provided by jinja2's loopcontrols.
+    if hasattr(lib, 'llama_chat_format') and hasattr(lib.llama_chat_format, 'Jinja2ChatFormatter'):
+        Formatter = lib.llama_chat_format.Jinja2ChatFormatter
+
+        if not getattr(Formatter, '_is_patched', False):
+            def patched_init(self, *args, **kwargs):
+                """Replacement __init__ that compiles the template with loopcontrols.
+
+                Accepts template, eos_token, bos_token, add_generation_prompt
+                and stop_token_ids, each positionally (in that order) or by
+                keyword; a positional value takes precedence over a keyword.
+                """
+                # Import the submodules explicitly: a bare `import jinja2`
+                # does not guarantee that `jinja2.sandbox` is loaded.
+                from jinja2 import BaseLoader
+                from jinja2.ext import loopcontrols
+                from jinja2.sandbox import ImmutableSandboxedEnvironment
+
+                # (name, default) pairs in positional order
+                parameters = (
+                    ('template', None),
+                    ('eos_token', None),
+                    ('bos_token', None),
+                    ('add_generation_prompt', True),
+                    ('stop_token_ids', None),
+                )
+                for position, (name, default) in enumerate(parameters):
+                    if position < len(args):
+                        value = args[position]
+                    else:
+                        value = kwargs.get(name, default)
+                    setattr(self, name, value)
+
+                # Process stop tokens as in the original
+                if self.stop_token_ids is not None:
+                    self.stop_token_ids = set(self.stop_token_ids)
+
+                # Create environment with loopcontrols extension
+                self._environment = ImmutableSandboxedEnvironment(
+                    loader=BaseLoader(),
+                    trim_blocks=True,
+                    lstrip_blocks=True,
+                    extensions=[loopcontrols]
+                ).from_string(self.template)
+
+            # Replace the original __init__ with our patched version
+            Formatter.__init__ = patched_init
+            Formatter._is_patched = True
+
     # Set the flag to indicate that the patch has been applied
     lib.Llama._is_patched = True