Fix jinja2 error while loading c4ai-command-a-03-2025

This commit is contained in:
oobabooga 2025-03-14 10:59:05 -07:00
parent f04a37adc2
commit 26317a4c7e
2 changed files with 46 additions and 1 deletions

View file

@ -11,6 +11,7 @@ from pathlib import Path
import gradio as gr
import yaml
from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment
from PIL import Image
@ -35,7 +36,11 @@ def strftime_now(format):
return datetime.now().strftime(format)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
jinja_env = ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[loopcontrols]
)
jinja_env.globals["strftime_now"] = strftime_now

View file

@ -121,5 +121,45 @@ def monkey_patch_llama_cpp_python(lib):
lib.Llama.original_generate = lib.Llama.generate
lib.Llama.generate = my_generate
# Also patch Jinja2ChatFormatter to handle loop controls
if hasattr(lib, 'llama_chat_format') and hasattr(lib.llama_chat_format, 'Jinja2ChatFormatter'):
Formatter = lib.llama_chat_format.Jinja2ChatFormatter
if not getattr(Formatter, '_is_patched', False):
def patched_init(self, *args, **kwargs):
# Extract parameters from args or kwargs
if args:
self.template = args[0]
self.eos_token = args[1] if len(args) > 1 else kwargs.get('eos_token')
self.bos_token = args[2] if len(args) > 2 else kwargs.get('bos_token')
self.add_generation_prompt = args[3] if len(args) > 3 else kwargs.get('add_generation_prompt', True)
self.stop_token_ids = args[4] if len(args) > 4 else kwargs.get('stop_token_ids')
else:
self.template = kwargs.get('template')
self.eos_token = kwargs.get('eos_token')
self.bos_token = kwargs.get('bos_token')
self.add_generation_prompt = kwargs.get('add_generation_prompt', True)
self.stop_token_ids = kwargs.get('stop_token_ids')
# Process stop tokens as in the original
self.stop_token_ids = (
set(self.stop_token_ids) if self.stop_token_ids is not None else None
)
# Create environment with loopcontrols extension
from jinja2.ext import loopcontrols
import jinja2
self._environment = jinja2.sandbox.ImmutableSandboxedEnvironment(
loader=jinja2.BaseLoader(),
trim_blocks=True,
lstrip_blocks=True,
extensions=[loopcontrols]
).from_string(self.template)
# Replace the original __init__ with our patched version
Formatter.__init__ = patched_init
Formatter._is_patched = True
# Set the flag to indicate that the patch has been applied
lib.Llama._is_patched = True